From a131e4d0e5cb9594b709b906200b53d983ea97fc Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Wed, 25 Feb 2026 09:40:57 +0530
Subject: [PATCH] sample gen support sft (#3240) [skip ci]

* add:parameters + callback
* sft core + logging
* indentation fix
* logger fix
* loger fix in sft
* gen sample on eval
* lint
* deprecation
---
 src/axolotl/core/builders/causal.py       |   6 +
 src/axolotl/utils/callbacks/generation.py |  84 +++++++++++
 src/axolotl/utils/generation/__init__.py  |   5 +
 src/axolotl/utils/generation/sft.py       | 174 ++++++++++++++++++++++
 src/axolotl/utils/schemas/config.py       |  52 +++++--
 src/axolotl/utils/schemas/deprecated.py   |  23 +++
 6 files changed, 332 insertions(+), 12 deletions(-)
 create mode 100644 src/axolotl/utils/callbacks/generation.py
 create mode 100644 src/axolotl/utils/generation/__init__.py
 create mode 100644 src/axolotl/utils/generation/sft.py

diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index 7bfc5e874..c238cbbc3 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -122,6 +122,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ColabCallback = colab_inference_post_train_callback(trainer)
             callbacks.append(ColabCallback(self.cfg))
 
+        if getattr(self.cfg, "generate_samples", False):
+            from axolotl.utils.callbacks.generation import SFTGenerationCallback
+
+            callbacks.append(SFTGenerationCallback(trainer))
+            LOG.info("SFT sample generation enabled")
+
         callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
         return callbacks
 
diff --git a/src/axolotl/utils/callbacks/generation.py b/src/axolotl/utils/callbacks/generation.py
new file mode 100644
index 000000000..439258c8b
--- /dev/null
+++ b/src/axolotl/utils/callbacks/generation.py
@@ -0,0 +1,84 @@
+"""Callback for generating samples during SFT/Pretrain training."""
+
+from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
+from transformers.training_args import TrainingArguments
+
+from axolotl.utils.generation.sft import generate_samples
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+class SFTGenerationCallback(TrainerCallback):
+    """Callback for generating samples during SFT/Pretrain training."""
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+
+    def on_evaluate(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        """Generate samples at specified intervals."""
+        cfg = self.trainer.axolotl_cfg
+
+        if not getattr(cfg, "generate_samples", False):
+            return
+
+        dataloader = None
+        try:
+            if getattr(self.trainer, "eval_dataset", None) is not None:
+                dataloader = self.trainer.get_eval_dataloader()
+                LOG.info(
+                    f"Using eval dataloader for generation at step {state.global_step}"
+                )
+        except Exception as e:
+            LOG.warning(f"Could not get eval dataloader: {e}")
+            dataloader = None
+
+        if dataloader is None:
+            dataloader = self.trainer.get_train_dataloader()
+            LOG.info(
+                f"Using train dataloader for generation at step {state.global_step}"
+            )
+
+        samples = generate_samples(
+            model=self.trainer.model,
+            tokenizer=self.trainer.processing_class,
+            dataloader=dataloader,
+            num_generation_samples=getattr(cfg, "num_generation_samples", 3),
+            max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
+            temperature=getattr(cfg, "generation_temperature", 0.7),
+            top_p=getattr(cfg, "generation_top_p", None),
+            top_k=getattr(cfg, "generation_top_k", None),
+            do_sample=getattr(cfg, "generation_do_sample", True),
+            prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
+        )
+        self._log_samples(samples, state.global_step)
+
+    def _log_samples(self, samples: list, step: int):
+        """Log generated samples to console and W&B."""
+        from axolotl.utils.generation.sft import format_generation_for_logging
+
+        for i, sample in enumerate(samples):
+            console_text, wandb_text = format_generation_for_logging(sample, i, step)
+
+            LOG.info(console_text)
+
+            try:
+                import wandb
+
+                if wandb.run is not None:
+                    wandb.log(
+                        {
+                            f"samples/sample_{i + 1}": wandb.Html(
+                                f"<pre>{wandb_text}</pre>"
+                            )
+                        },
+                        step=step,
+                    )
+            except (ImportError, Exception):
+                pass
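
Wiring sketch (illustrative, not part of the diff): within axolotl the callback is attached by HFCausalTrainerBuilder as shown in the causal.py hunk above. The sketch below attaches it by hand to a plain transformers.Trainer, assuming a transformers version that accepts processing_class; the toy checkpoint, toy dataset, and the SimpleNamespace standing in for the axolotl config are assumptions made only for this example.

from types import SimpleNamespace

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

from axolotl.utils.callbacks.generation import SFTGenerationCallback

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # toy checkpoint
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

def tokenize(example):
    out = tokenizer(example["text"], truncation=True, padding="max_length", max_length=32)
    out["labels"] = out["input_ids"].copy()
    return out

ds = Dataset.from_dict(
    {"text": ["Sample generation runs whenever the trainer evaluates."] * 8}
).map(tokenize, remove_columns=["text"])

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="sft-gen-demo", per_device_eval_batch_size=4, report_to="none"),
    train_dataset=ds,
    eval_dataset=ds,
    processing_class=tokenizer,
)
# The callback reads its settings from trainer.axolotl_cfg via getattr; set it by hand here.
trainer.axolotl_cfg = SimpleNamespace(
    generate_samples=True,
    num_generation_samples=2,
    generation_max_new_tokens=20,
)
trainer.add_callback(SFTGenerationCallback(trainer))
trainer.evaluate()  # on_evaluate fires and the generated samples are logged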
diff --git a/src/axolotl/utils/generation/__init__.py b/src/axolotl/utils/generation/__init__.py
new file mode 100644
index 000000000..7a222d18d
--- /dev/null
+++ b/src/axolotl/utils/generation/__init__.py
@@ -0,0 +1,5 @@
+"""Generation utilities for monitoring during training."""
+
+from .sft import format_generation_for_logging, generate_samples
+
+__all__ = ["generate_samples", "format_generation_for_logging"]
diff --git a/src/axolotl/utils/generation/sft.py b/src/axolotl/utils/generation/sft.py
new file mode 100644
index 000000000..70fff80c5
--- /dev/null
+++ b/src/axolotl/utils/generation/sft.py
@@ -0,0 +1,174 @@
+"""Sample generation utilities for SFT/Pretrain training."""
+
+from typing import Any, List, Optional
+
+import torch
+from accelerate.utils import extract_model_from_parallel
+from colorama import Fore, Style
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+def generate_samples(
+ model: torch.nn.Module,
+ tokenizer: Any,
+ dataloader: Any,
+ num_generation_samples: int = 3,
+ max_new_tokens: int = 50,
+ temperature: float = 0.7,
+ top_p: Optional[float] = None,
+ top_k: Optional[int] = None,
+ do_sample: bool = True,
+ prompt_ratio: float = 0.5,
+) -> List[dict]:
+ """
+ Generate samples from the model during training for monitoring.
+
+ Args:
+ model: The model to generate from
+ tokenizer: The tokenizer to use for encoding/decoding
+ dataloader: Dataloader to sample prompts from
+ num_generation_samples: Number of samples to generate
+ max_new_tokens: Maximum new tokens to generate
+ temperature: Sampling temperature (0.0 = greedy)
+ top_p: Nucleus sampling parameter
+ top_k: Top-k sampling parameter
+ do_sample: Whether to use sampling vs greedy decoding
+ prompt_ratio: Ratio of sequence to use as prompt (0.0-1.0)
+
+ Returns:
+ List of dicts with 'prompt', 'generated', and 'full_text' keys
+ """
+ unwrapped_model = extract_model_from_parallel(model)
+
+ training = unwrapped_model.training
+ unwrapped_model.eval()
+
+ device = next(unwrapped_model.parameters()).device
+
+ generations = []
+
+ try:
+ with torch.no_grad():
+ samples_collected = 0
+
+ for batch in dataloader:
+ if samples_collected >= num_generation_samples:
+ break
+
+ input_ids = batch["input_ids"].to(device)
+ attention_mask = batch.get("attention_mask")
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(device)
+ batch_size = input_ids.shape[0]
+
+ indices = torch.randperm(batch_size)[
+ : num_generation_samples - samples_collected
+ ]
+
+ for idx in indices:
+ if samples_collected >= num_generation_samples:
+ break
+
+ sequence = input_ids[idx]
+
+ if attention_mask is not None:
+ seq_len = attention_mask[idx].sum().item()
+ else:
+ seq_len = sequence.shape[0]
+
+ if seq_len < 5:
+ continue
+
+ prompt_len = max(1, int(seq_len * prompt_ratio))
+ prompt_ids = sequence[:prompt_len].unsqueeze(0)
+
+ try:
+ generation_config = {
+ "max_new_tokens": max_new_tokens,
+ "do_sample": do_sample,
+ "pad_token_id": tokenizer.pad_token_id
+ if tokenizer.pad_token_id is not None
+ else tokenizer.eos_token_id,
+ }
+
+ if do_sample:
+ generation_config["temperature"] = temperature
+ if top_p is not None:
+ generation_config["top_p"] = top_p
+ if top_k is not None:
+ generation_config["top_k"] = top_k
+
+ generated_ids = unwrapped_model.generate(
+ prompt_ids, **generation_config
+ )
+
+ prompt_text = tokenizer.decode(
+ prompt_ids[0], skip_special_tokens=True
+ )
+ generated_text = tokenizer.decode(
+ generated_ids[0][prompt_len:], skip_special_tokens=True
+ )
+ full_text = tokenizer.decode(
+ generated_ids[0], skip_special_tokens=True
+ )
+
+ generations.append(
+ {
+ "prompt": prompt_text,
+ "generated": generated_text,
+ "full_text": full_text,
+ }
+ )
+
+ samples_collected += 1
+
+ except Exception as e:
+ LOG.warning(f"Failed to generate sample: {e}", exc_info=True)
+ continue
+
+ except Exception as e:
+ LOG.warning(f"Error during sample generation: {e}", exc_info=True)
+
+ if training:
+ unwrapped_model.train()
+ else:
+ unwrapped_model.eval()
+
+ return generations
+
+
+def format_generation_for_logging(
+ sample: dict, sample_idx: int, step: int
+) -> tuple[str, str]:
+ """
+ Format a generation sample for pretty logging.
+
+ Args:
+ sample: Dict with 'prompt', 'generated', and 'full_text' keys
+ sample_idx: Index of the sample
+ step: Current training step
+
+ Returns:
+ Tuple of (console_text, wandb_text)
+ """
+ console_text = (
+ f"\n{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
+ f"{Style.BRIGHT}{Fore.GREEN}Sample {sample_idx + 1} (Step {step}){Style.RESET_ALL}\n"
+ f"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
+ f"{Style.BRIGHT}{Fore.YELLOW}[PROMPT]{Style.RESET_ALL}\n{sample['prompt']}\n\n"
+ f"{Style.BRIGHT}{Fore.MAGENTA}[GENERATED]{Style.RESET_ALL}\n{sample['generated']}\n"
+ f"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
+ )
+ wandb_text = (
+ f"\n{'=' * 80}\n"
+ f"Sample {sample_idx + 1} (Step {step})\n"
+ f"{'=' * 80}\n"
+ f"[PROMPT]\n{sample['prompt']}\n\n"
+ f"[GENERATED]\n{sample['generated']}\n"
+ f"{'=' * 80}\n"
+ )
+
+ return console_text, wandb_text
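
Usage sketch (illustrative, not part of the diff): a minimal way to call generate_samples directly, outside the trainer callback. The tiny checkpoint and example texts are placeholders; per the signature above, any causal LM plus any iterable of dicts carrying input_ids (and optionally attention_mask) tensors should work.

from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.utils.generation.sft import generate_samples

# Any small causal LM will do; "sshleifer/tiny-gpt2" is just a toy checkpoint.
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

# generate_samples only iterates the dataloader, so a list of tokenized batches suffices.
texts = [
    "The axolotl is a paedomorphic salamander native to lakes near Mexico City.",
    "Supervised fine-tuning adapts a pretrained language model to instruction data.",
]
dataloader = [dict(tokenizer(texts, padding=True, return_tensors="pt"))]

samples = generate_samples(
    model=model,
    tokenizer=tokenizer,
    dataloader=dataloader,
    num_generation_samples=2,
    max_new_tokens=20,
    temperature=0.7,
    prompt_ratio=0.5,  # first half of each sequence is used as the prompt
)
for sample in samples:
    print(sample["prompt"], "->", sample["generated"])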
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index f1627367b..35b4a6908 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -338,18 +338,6 @@ class AxolotlInputConfig(
)
ddp_find_unused_parameters: bool | None = None
- eval_table_size: int | None = Field(
- default=None,
- json_schema_extra={
- "description": "Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0"
- },
- )
- eval_max_new_tokens: int | None = Field(
- default=None,
- json_schema_extra={
- "description": "Total number of tokens generated for predictions sent to wandb. Default is 128"
- },
- )
do_causal_lm_eval: bool | None = Field(
default=None,
json_schema_extra={
@@ -1106,6 +1094,46 @@ class AxolotlInputConfig(
"description": "Add plugins to extend the pipeline. See `src/axolotl/integrations` for the available plugins or doc below for more details. https://docs.axolotl.ai/docs/custom_integrations.html"
},
)
+ generate_samples: bool | None = Field(
+ default=False,
+ json_schema_extra={
+ "description": "Enable sample generation during training for monitoring"
+ },
+ )
+ num_generation_samples: int | None = Field(
+ default=3,
+ json_schema_extra={
+ "description": "Number of samples to generate at each interval"
+ },
+ )
+ generation_max_new_tokens: int | None = Field(
+ default=50,
+ json_schema_extra={"description": "Maximum new tokens to generate per sample"},
+ )
+ generation_temperature: float | None = Field(
+ default=0.7,
+ json_schema_extra={
+ "description": "Temperature for sample generation (0.0 = greedy)"
+ },
+ )
+ generation_top_p: float | None = Field(
+ default=None,
+ json_schema_extra={"description": "Nucleus sampling parameter for generation"},
+ )
+ generation_top_k: int | None = Field(
+ default=None,
+ json_schema_extra={"description": "Top-k sampling parameter for generation"},
+ )
+ generation_prompt_ratio: float | None = Field(
+ default=0.5,
+ json_schema_extra={"description": "Ratio of input to use as prompt (0.0-1.0)"},
+ )
+ generation_do_sample: bool | None = Field(
+ default=True,
+ json_schema_extra={
+ "description": "Whether to use sampling (vs greedy decoding)"
+ },
+ )
@field_serializer("datasets")
def datasets_serializer(
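
Illustrative only (not part of the diff): the new fields map one-to-one onto top-level config keys that the callback reads via getattr; a sketch of plausible values, shown as a plain mapping rather than an actual axolotl config.

sample_generation_options = {
    "generate_samples": True,        # enables the SFTGenerationCallback
    "num_generation_samples": 3,     # samples generated per evaluation
    "generation_max_new_tokens": 50,
    "generation_temperature": 0.7,
    "generation_top_p": 0.9,         # optional; leave unset to disable nucleus sampling
    "generation_top_k": None,        # optional; leave unset to disable top-k filtering
    "generation_prompt_ratio": 0.5,  # first half of each sequence is the prompt
    "generation_do_sample": True,    # False falls back to greedy decoding
}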
diff --git a/src/axolotl/utils/schemas/deprecated.py b/src/axolotl/utils/schemas/deprecated.py
index 972fe0ccf..9dfe69264 100644
--- a/src/axolotl/utils/schemas/deprecated.py
+++ b/src/axolotl/utils/schemas/deprecated.py
@@ -17,6 +17,8 @@ class DeprecatedParameters(BaseModel):
noisy_embedding_alpha: float | None = None
dpo_beta: float | None = None
evaluation_strategy: str | None = None
+ eval_table_size: int | None = None
+ eval_max_new_tokens: int | None = None
@field_validator("max_packed_sequence_len")
@classmethod
@@ -55,6 +57,27 @@ class DeprecatedParameters(BaseModel):
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
return evaluation_strategy
+ @field_validator("eval_table_size")
+ @classmethod
+ def validate_eval_table_size(cls, eval_table_size):
+ if eval_table_size is not None:
+ LOG.warning(
+ "eval_table_size is deprecated and superseded by generate_samples config. "
+ "Please use generate_samples: true and num_generation_samples instead. "
+ "The LogPredictionCallback is replaced by the new sample generation feature."
+ )
+ return eval_table_size
+
+ @field_validator("eval_max_new_tokens")
+ @classmethod
+ def validate_eval_max_new_tokens(cls, eval_max_new_tokens):
+ if eval_max_new_tokens is not None:
+ LOG.warning(
+ "eval_max_new_tokens is deprecated and superseded by generate_samples config. "
+ "Please use generation_max_new_tokens instead."
+ )
+ return eval_max_new_tokens
+
class RemappedParameters(BaseModel):
"""Parameters that have been remapped to other names"""