From a131e4d0e5cb9594b709b906200b53d983ea97fc Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Wed, 25 Feb 2026 09:40:57 +0530 Subject: [PATCH] sample gen support sft (#3240) [skip ci] * add:parameters + callback * sft core + logging * indentation fix * logger fix * loger fix in sft * gen sample on eval * lint * deprecation --- src/axolotl/core/builders/causal.py | 6 + src/axolotl/utils/callbacks/generation.py | 84 +++++++++++ src/axolotl/utils/generation/__init__.py | 5 + src/axolotl/utils/generation/sft.py | 174 ++++++++++++++++++++++ src/axolotl/utils/schemas/config.py | 52 +++++-- src/axolotl/utils/schemas/deprecated.py | 23 +++ 6 files changed, 332 insertions(+), 12 deletions(-) create mode 100644 src/axolotl/utils/callbacks/generation.py create mode 100644 src/axolotl/utils/generation/__init__.py create mode 100644 src/axolotl/utils/generation/sft.py diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 7bfc5e874..c238cbbc3 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -122,6 +122,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ColabCallback = colab_inference_post_train_callback(trainer) callbacks.append(ColabCallback(self.cfg)) + if getattr(self.cfg, "generate_samples", False): + from axolotl.utils.callbacks.generation import SFTGenerationCallback + + callbacks.append(SFTGenerationCallback(trainer)) + LOG.info("SFT sample generation enabled") + callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer)) return callbacks diff --git a/src/axolotl/utils/callbacks/generation.py b/src/axolotl/utils/callbacks/generation.py new file mode 100644 index 000000000..439258c8b --- /dev/null +++ b/src/axolotl/utils/callbacks/generation.py @@ -0,0 +1,84 @@ +"""Callback for generating samples during SFT/Pretrain training.""" + +from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState +from transformers.training_args import TrainingArguments + +from axolotl.utils.generation.sft import generate_samples +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class SFTGenerationCallback(TrainerCallback): + """Callback for generating samples during SFT/Pretrain training.""" + + def __init__(self, trainer): + self.trainer = trainer + + def on_evaluate( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Generate samples at specified intervals.""" + cfg = self.trainer.axolotl_cfg + + if not getattr(cfg, "generate_samples", False): + return + + dataloader = None + try: + if getattr(self.trainer, "eval_dataset", None) is not None: + dataloader = self.trainer.get_eval_dataloader() + LOG.info( + f"Using eval dataloader for generation at step {state.global_step}" + ) + except Exception as e: + LOG.warning(f"Could not get eval dataloader: {e}") + dataloader = None + + if dataloader is None: + dataloader = self.trainer.get_train_dataloader() + LOG.info( + f"Using train dataloader for generation at step {state.global_step}" + ) + + samples = generate_samples( + model=self.trainer.model, + tokenizer=self.trainer.processing_class, + dataloader=dataloader, + num_generation_samples=getattr(cfg, "num_generation_samples", 3), + max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50), + temperature=getattr(cfg, "generation_temperature", 0.7), + top_p=getattr(cfg, "generation_top_p", None), + top_k=getattr(cfg, "generation_top_k", None), + do_sample=getattr(cfg, "generation_do_sample", True), + prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5), + ) + self._log_samples(samples, state.global_step) + + def _log_samples(self, samples: list, step: int): + """Log generated samples to console and W&B.""" + from axolotl.utils.generation.sft import format_generation_for_logging + + for i, sample in enumerate(samples): + console_text, wandb_text = format_generation_for_logging(sample, i, step) + + LOG.info(console_text) + + try: + import wandb + + if wandb.run is not None: + wandb.log( + { + f"samples/sample_{i + 1}": wandb.Html( + f"
{wandb_text}
" + ) + }, + step=step, + ) + except (ImportError, Exception): + pass diff --git a/src/axolotl/utils/generation/__init__.py b/src/axolotl/utils/generation/__init__.py new file mode 100644 index 000000000..7a222d18d --- /dev/null +++ b/src/axolotl/utils/generation/__init__.py @@ -0,0 +1,5 @@ +"""Generation utilities for monitoring during training.""" + +from .sft import format_generation_for_logging, generate_samples + +__all__ = ["generate_samples", "format_generation_for_logging"] diff --git a/src/axolotl/utils/generation/sft.py b/src/axolotl/utils/generation/sft.py new file mode 100644 index 000000000..70fff80c5 --- /dev/null +++ b/src/axolotl/utils/generation/sft.py @@ -0,0 +1,174 @@ +"""Sample generation utilities for SFT/Pretrain training.""" + +from typing import Any, List, Optional + +import torch +from accelerate.utils import extract_model_from_parallel +from colorama import Fore, Style + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def generate_samples( + model: torch.nn.Module, + tokenizer: Any, + dataloader: Any, + num_generation_samples: int = 3, + max_new_tokens: int = 50, + temperature: float = 0.7, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + do_sample: bool = True, + prompt_ratio: float = 0.5, +) -> List[dict]: + """ + Generate samples from the model during training for monitoring. + + Args: + model: The model to generate from + tokenizer: The tokenizer to use for encoding/decoding + dataloader: Dataloader to sample prompts from + num_generation_samples: Number of samples to generate + max_new_tokens: Maximum new tokens to generate + temperature: Sampling temperature (0.0 = greedy) + top_p: Nucleus sampling parameter + top_k: Top-k sampling parameter + do_sample: Whether to use sampling vs greedy decoding + prompt_ratio: Ratio of sequence to use as prompt (0.0-1.0) + + Returns: + List of dicts with 'prompt', 'generated', and 'full_text' keys + """ + unwrapped_model = extract_model_from_parallel(model) + + training = unwrapped_model.training + unwrapped_model.eval() + + device = next(unwrapped_model.parameters()).device + + generations = [] + + try: + with torch.no_grad(): + samples_collected = 0 + + for batch in dataloader: + if samples_collected >= num_generation_samples: + break + + input_ids = batch["input_ids"].to(device) + attention_mask = batch.get("attention_mask") + if attention_mask is not None: + attention_mask = attention_mask.to(device) + batch_size = input_ids.shape[0] + + indices = torch.randperm(batch_size)[ + : num_generation_samples - samples_collected + ] + + for idx in indices: + if samples_collected >= num_generation_samples: + break + + sequence = input_ids[idx] + + if attention_mask is not None: + seq_len = attention_mask[idx].sum().item() + else: + seq_len = sequence.shape[0] + + if seq_len < 5: + continue + + prompt_len = max(1, int(seq_len * prompt_ratio)) + prompt_ids = sequence[:prompt_len].unsqueeze(0) + + try: + generation_config = { + "max_new_tokens": max_new_tokens, + "do_sample": do_sample, + "pad_token_id": tokenizer.pad_token_id + if tokenizer.pad_token_id is not None + else tokenizer.eos_token_id, + } + + if do_sample: + generation_config["temperature"] = temperature + if top_p is not None: + generation_config["top_p"] = top_p + if top_k is not None: + generation_config["top_k"] = top_k + + generated_ids = unwrapped_model.generate( + prompt_ids, **generation_config + ) + + prompt_text = tokenizer.decode( + prompt_ids[0], skip_special_tokens=True + ) + generated_text = tokenizer.decode( + generated_ids[0][prompt_len:], skip_special_tokens=True + ) + full_text = tokenizer.decode( + generated_ids[0], skip_special_tokens=True + ) + + generations.append( + { + "prompt": prompt_text, + "generated": generated_text, + "full_text": full_text, + } + ) + + samples_collected += 1 + + except Exception as e: + LOG.warning(f"Failed to generate sample: {e}", exc_info=True) + continue + + except Exception as e: + LOG.warning(f"Error during sample generation: {e}", exc_info=True) + + if training: + unwrapped_model.train() + else: + unwrapped_model.eval() + + return generations + + +def format_generation_for_logging( + sample: dict, sample_idx: int, step: int +) -> tuple[str, str]: + """ + Format a generation sample for pretty logging. + + Args: + sample: Dict with 'prompt', 'generated', and 'full_text' keys + sample_idx: Index of the sample + step: Current training step + + Returns: + Tuple of (console_text, wandb_text) + """ + console_text = ( + f"\n{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n" + f"{Style.BRIGHT}{Fore.GREEN}Sample {sample_idx + 1} (Step {step}){Style.RESET_ALL}\n" + f"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n" + f"{Style.BRIGHT}{Fore.YELLOW}[PROMPT]{Style.RESET_ALL}\n{sample['prompt']}\n\n" + f"{Style.BRIGHT}{Fore.MAGENTA}[GENERATED]{Style.RESET_ALL}\n{sample['generated']}\n" + f"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n" + ) + wandb_text = ( + f"\n{'=' * 80}\n" + f"Sample {sample_idx + 1} (Step {step})\n" + f"{'=' * 80}\n" + f"[PROMPT]\n{sample['prompt']}\n\n" + f"[GENERATED]\n{sample['generated']}\n" + f"{'=' * 80}\n" + ) + + return console_text, wandb_text diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index f1627367b..35b4a6908 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -338,18 +338,6 @@ class AxolotlInputConfig( ) ddp_find_unused_parameters: bool | None = None - eval_table_size: int | None = Field( - default=None, - json_schema_extra={ - "description": "Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0" - }, - ) - eval_max_new_tokens: int | None = Field( - default=None, - json_schema_extra={ - "description": "Total number of tokens generated for predictions sent to wandb. Default is 128" - }, - ) do_causal_lm_eval: bool | None = Field( default=None, json_schema_extra={ @@ -1106,6 +1094,46 @@ class AxolotlInputConfig( "description": "Add plugins to extend the pipeline. See `src/axolotl/integrations` for the available plugins or doc below for more details. https://docs.axolotl.ai/docs/custom_integrations.html" }, ) + generate_samples: bool | None = Field( + default=False, + json_schema_extra={ + "description": "Enable sample generation during training for monitoring" + }, + ) + num_generation_samples: int | None = Field( + default=3, + json_schema_extra={ + "description": "Number of samples to generate at each interval" + }, + ) + generation_max_new_tokens: int | None = Field( + default=50, + json_schema_extra={"description": "Maximum new tokens to generate per sample"}, + ) + generation_temperature: float | None = Field( + default=0.7, + json_schema_extra={ + "description": "Temperature for sample generation (0.0 = greedy)" + }, + ) + generation_top_p: float | None = Field( + default=None, + json_schema_extra={"description": "Nucleus sampling parameter for generation"}, + ) + generation_top_k: int | None = Field( + default=None, + json_schema_extra={"description": "Top-k sampling parameter for generation"}, + ) + generation_prompt_ratio: float | None = Field( + default=0.5, + json_schema_extra={"description": "Ratio of input to use as prompt (0.0-1.0)"}, + ) + generation_do_sample: bool | None = Field( + default=True, + json_schema_extra={ + "description": "Whether to use sampling (vs greedy decoding)" + }, + ) @field_serializer("datasets") def datasets_serializer( diff --git a/src/axolotl/utils/schemas/deprecated.py b/src/axolotl/utils/schemas/deprecated.py index 972fe0ccf..9dfe69264 100644 --- a/src/axolotl/utils/schemas/deprecated.py +++ b/src/axolotl/utils/schemas/deprecated.py @@ -17,6 +17,8 @@ class DeprecatedParameters(BaseModel): noisy_embedding_alpha: float | None = None dpo_beta: float | None = None evaluation_strategy: str | None = None + eval_table_size: int | None = None + eval_max_new_tokens: int | None = None @field_validator("max_packed_sequence_len") @classmethod @@ -55,6 +57,27 @@ class DeprecatedParameters(BaseModel): LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead") return evaluation_strategy + @field_validator("eval_table_size") + @classmethod + def validate_eval_table_size(cls, eval_table_size): + if eval_table_size is not None: + LOG.warning( + "eval_table_size is deprecated and superseded by generate_samples config. " + "Please use generate_samples: true and num_generation_samples instead. " + "The LogPredictionCallback is replaced by the new sample generation feature." + ) + return eval_table_size + + @field_validator("eval_max_new_tokens") + @classmethod + def validate_eval_max_new_tokens(cls, eval_max_new_tokens): + if eval_max_new_tokens is not None: + LOG.warning( + "eval_max_new_tokens is deprecated and superseded by generate_samples config. " + "Please use generation_max_new_tokens instead." + ) + return eval_max_new_tokens + class RemappedParameters(BaseModel): """Parameters that have been remapped to other names"""