sample gen support sft (#3240) [skip ci]
* add:parameters + callback * sft core + logging * indentation fix * logger fix * logger fix in sft * gen sample on eval * lint * deprecation
This commit is contained in:
@@ -122,6 +122,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
ColabCallback = colab_inference_post_train_callback(trainer)
|
ColabCallback = colab_inference_post_train_callback(trainer)
|
||||||
callbacks.append(ColabCallback(self.cfg))
|
callbacks.append(ColabCallback(self.cfg))
|
||||||
|
|
||||||
|
if getattr(self.cfg, "generate_samples", False):
|
||||||
|
from axolotl.utils.callbacks.generation import SFTGenerationCallback
|
||||||
|
|
||||||
|
callbacks.append(SFTGenerationCallback(trainer))
|
||||||
|
LOG.info("SFT sample generation enabled")
|
||||||
|
|
||||||
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
|
|||||||
84
src/axolotl/utils/callbacks/generation.py
Normal file
84
src/axolotl/utils/callbacks/generation.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Callback for generating samples during SFT/Pretrain training."""
|
||||||
|
|
||||||
|
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
|
||||||
|
from transformers.training_args import TrainingArguments
|
||||||
|
|
||||||
|
from axolotl.utils.generation.sft import generate_samples
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SFTGenerationCallback(TrainerCallback):
    """Callback for generating samples during SFT/Pretrain training.

    On every evaluation event the callback draws prompts from the trainer's
    eval dataloader (falling back to the train dataloader), generates
    continuations with the current model, and logs them to the console and,
    when a W&B run is active, to W&B.
    """

    def __init__(self, trainer):
        # The trainer supplies the model, the tokenizer (``processing_class``),
        # the dataloaders and the axolotl config read in ``on_evaluate``.
        self.trainer = trainer

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Generate and log samples each time the trainer evaluates.

        Does nothing unless ``generate_samples`` is truthy on the trainer's
        axolotl config. ``args``, ``control`` and ``kwargs`` are accepted for
        callback-signature compatibility and are not used here.
        """
        cfg = self.trainer.axolotl_cfg

        if not getattr(cfg, "generate_samples", False):
            return

        dataloader = None
        try:
            if getattr(self.trainer, "eval_dataset", None) is not None:
                dataloader = self.trainer.get_eval_dataloader()
                LOG.info(
                    f"Using eval dataloader for generation at step {state.global_step}"
                )
        except Exception as e:
            # Best effort: any failure here just triggers the train fallback.
            LOG.warning(f"Could not get eval dataloader: {e}")
            dataloader = None

        if dataloader is None:
            dataloader = self.trainer.get_train_dataloader()
            LOG.info(
                f"Using train dataloader for generation at step {state.global_step}"
            )

        samples = generate_samples(
            model=self.trainer.model,
            tokenizer=self.trainer.processing_class,
            dataloader=dataloader,
            num_generation_samples=getattr(cfg, "num_generation_samples", 3),
            max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
            temperature=getattr(cfg, "generation_temperature", 0.7),
            top_p=getattr(cfg, "generation_top_p", None),
            top_k=getattr(cfg, "generation_top_k", None),
            do_sample=getattr(cfg, "generation_do_sample", True),
            prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
        )
        self._log_samples(samples, state.global_step)

    def _log_samples(self, samples: list, step: int):
        """Log generated samples to console and W&B.

        Console logging always happens. W&B logging is best effort: it is
        skipped when ``wandb`` cannot be imported or no run is active, and a
        failure to log one sample is reported as a warning without stopping
        the remaining samples.
        """
        import html

        from axolotl.utils.generation.sft import format_generation_for_logging

        # Resolve W&B availability once instead of once per sample.
        wandb = None
        wandb_active = False
        try:
            import wandb

            wandb_active = wandb.run is not None
        except ImportError:
            # wandb is optional; console logging below still happens.
            pass
        except Exception as exc:
            LOG.warning(f"W&B is unavailable for sample logging: {exc}")

        for i, sample in enumerate(samples):
            console_text, wandb_text = format_generation_for_logging(sample, i, step)

            LOG.info(console_text)

            if not wandb_active:
                continue

            try:
                # The text is model output: escape it so characters such as
                # '<' or '&' are displayed literally instead of being parsed
                # as markup inside the <pre> element.
                safe_text = html.escape(wandb_text, quote=False)
                wandb.log(
                    {f"samples/sample_{i + 1}": wandb.Html(f"<pre>{safe_text}</pre>")},
                    step=step,
                )
            except Exception as exc:
                # Monitoring must never interrupt training; report and move on.
                LOG.warning(f"Could not log sample {i + 1} to W&B: {exc}")
|
||||||
5
src/axolotl/utils/generation/__init__.py
Normal file
5
src/axolotl/utils/generation/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Generation utilities for monitoring during training."""
|
||||||
|
|
||||||
|
from .sft import format_generation_for_logging, generate_samples
|
||||||
|
|
||||||
|
__all__ = ["generate_samples", "format_generation_for_logging"]
|
||||||
174
src/axolotl/utils/generation/sft.py
Normal file
174
src/axolotl/utils/generation/sft.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""Sample generation utilities for SFT/Pretrain training."""
|
||||||
|
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from accelerate.utils import extract_model_from_parallel
|
||||||
|
from colorama import Fore, Style
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_samples(
    model: torch.nn.Module,
    tokenizer: Any,
    dataloader: Any,
    num_generation_samples: int = 3,
    max_new_tokens: int = 50,
    temperature: float = 0.7,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    do_sample: bool = True,
    prompt_ratio: float = 0.5,
) -> List[dict]:
    """
    Generate samples from the model during training for monitoring.

    Prompts are prefixes of sequences drawn at random from batches of
    ``dataloader``. The model is switched to eval mode for the duration of
    the call and its previous train/eval mode is always restored afterwards.

    Args:
        model: The model to generate from
        tokenizer: The tokenizer to use for encoding/decoding
        dataloader: Dataloader to sample prompts from; batches must provide
            "input_ids" and may provide "attention_mask"
        num_generation_samples: Number of samples to generate
        max_new_tokens: Maximum new tokens to generate
        temperature: Sampling temperature (0.0 or below = greedy)
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter
        do_sample: Whether to use sampling vs greedy decoding
        prompt_ratio: Ratio of sequence to use as prompt (0.0-1.0)

    Returns:
        List of dicts with 'prompt', 'generated', and 'full_text' keys. The
        list may hold fewer than ``num_generation_samples`` entries (possibly
        none): sequences shorter than 5 tokens are skipped, and generation
        errors are logged as warnings rather than raised.
    """
    unwrapped_model = extract_model_from_parallel(model)

    was_training = unwrapped_model.training
    unwrapped_model.eval()

    generations: List[dict] = []

    try:
        device = next(unwrapped_model.parameters()).device

        # A non-positive temperature cannot be used for sampling; honour the
        # documented "0.0 = greedy" contract by disabling sampling instead.
        greedy_requested = temperature is not None and temperature <= 0
        sample_tokens = do_sample and not greedy_requested

        # These settings do not depend on the prompt, so build them once.
        generation_config = {
            "max_new_tokens": max_new_tokens,
            "do_sample": sample_tokens,
            "pad_token_id": tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id,
        }
        if sample_tokens:
            generation_config["temperature"] = temperature
            if top_p is not None:
                generation_config["top_p"] = top_p
            if top_k is not None:
                generation_config["top_k"] = top_k

        with torch.no_grad():
            for batch in dataloader:
                remaining = num_generation_samples - len(generations)
                if remaining <= 0:
                    break

                input_ids = batch["input_ids"].to(device)
                attention_mask = batch.get("attention_mask")
                if attention_mask is not None:
                    attention_mask = attention_mask.to(device)
                batch_size = input_ids.shape[0]

                # Try at most as many random rows as samples are still needed.
                indices = torch.randperm(batch_size)[:remaining]

                for idx in indices:
                    sequence = input_ids[idx]

                    if attention_mask is not None:
                        # Count attended positions rather than summing mask
                        # values, so a mask holding values other than 0/1
                        # cannot inflate the length.
                        # NOTE(review): packed batches presumably carry
                        # per-sequence ids in the mask - confirm.
                        seq_len = int((attention_mask[idx] != 0).sum().item())
                    else:
                        seq_len = sequence.shape[0]

                    # Too short to split into a meaningful prompt.
                    if seq_len < 5:
                        continue

                    prompt_ids = sequence[
                        : max(1, int(seq_len * prompt_ratio))
                    ].unsqueeze(0)
                    # Use the real prompt length (slicing clamps it to the
                    # sequence length) when cutting the prompt back off.
                    prompt_len = prompt_ids.shape[1]

                    try:
                        generated_ids = unwrapped_model.generate(
                            prompt_ids, **generation_config
                        )

                        prompt_text = tokenizer.decode(
                            prompt_ids[0], skip_special_tokens=True
                        )
                        generated_text = tokenizer.decode(
                            generated_ids[0][prompt_len:], skip_special_tokens=True
                        )
                        full_text = tokenizer.decode(
                            generated_ids[0], skip_special_tokens=True
                        )

                        generations.append(
                            {
                                "prompt": prompt_text,
                                "generated": generated_text,
                                "full_text": full_text,
                            }
                        )
                    except Exception as e:
                        # One bad sample must not abort the remaining ones.
                        LOG.warning(f"Failed to generate sample: {e}", exc_info=True)

    except Exception as e:
        LOG.warning(f"Error during sample generation: {e}", exc_info=True)
    finally:
        # Restore the caller's mode even if generation was interrupted.
        unwrapped_model.train(was_training)

    return generations
|
||||||
|
|
||||||
|
|
||||||
|
def format_generation_for_logging(
    sample: dict, sample_idx: int, step: int
) -> tuple[str, str]:
    """
    Render one generation sample as a banner-framed text block.

    Two renderings of the same layout are produced: one decorated with
    colorama styling for the console and one plain-text version for W&B.

    Args:
        sample: Dict with 'prompt', 'generated', and 'full_text' keys
            (only 'prompt' and 'generated' are rendered)
        sample_idx: Zero-based index of the sample; displayed one-based
        step: Current training step

    Returns:
        Tuple of (console_text, wandb_text)
    """
    rule = "=" * 80
    title = f"Sample {sample_idx + 1} (Step {step})"

    def _paint(colour: str, text: str) -> str:
        # Bright, coloured text followed by a full style reset.
        return f"{Style.BRIGHT}{colour}{text}{Style.RESET_ALL}"

    def _layout(rule_line: str, title_line: str, prompt_tag: str, generated_tag: str) -> str:
        # Shared skeleton of both renderings; only the decorations differ.
        return (
            f"\n{rule_line}\n"
            f"{title_line}\n"
            f"{rule_line}\n"
            f"{prompt_tag}\n{sample['prompt']}\n\n"
            f"{generated_tag}\n{sample['generated']}\n"
            f"{rule_line}\n"
        )

    console_text = _layout(
        _paint(Fore.CYAN, rule),
        _paint(Fore.GREEN, title),
        _paint(Fore.YELLOW, "[PROMPT]"),
        _paint(Fore.MAGENTA, "[GENERATED]"),
    )
    wandb_text = _layout(rule, title, "[PROMPT]", "[GENERATED]")

    return console_text, wandb_text
|
||||||
@@ -338,18 +338,6 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
ddp_find_unused_parameters: bool | None = None
|
ddp_find_unused_parameters: bool | None = None
|
||||||
|
|
||||||
eval_table_size: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
eval_max_new_tokens: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Total number of tokens generated for predictions sent to wandb. Default is 128"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
do_causal_lm_eval: bool | None = Field(
|
do_causal_lm_eval: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -1106,6 +1094,46 @@ class AxolotlInputConfig(
|
|||||||
"description": "Add plugins to extend the pipeline. See `src/axolotl/integrations` for the available plugins or doc below for more details. https://docs.axolotl.ai/docs/custom_integrations.html"
|
"description": "Add plugins to extend the pipeline. See `src/axolotl/integrations` for the available plugins or doc below for more details. https://docs.axolotl.ai/docs/custom_integrations.html"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
generate_samples: bool | None = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Enable sample generation during training for monitoring"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
num_generation_samples: int | None = Field(
|
||||||
|
default=3,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of samples to generate at each interval"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
generation_max_new_tokens: int | None = Field(
|
||||||
|
default=50,
|
||||||
|
json_schema_extra={"description": "Maximum new tokens to generate per sample"},
|
||||||
|
)
|
||||||
|
generation_temperature: float | None = Field(
|
||||||
|
default=0.7,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Temperature for sample generation (0.0 = greedy)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
generation_top_p: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Nucleus sampling parameter for generation"},
|
||||||
|
)
|
||||||
|
generation_top_k: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Top-k sampling parameter for generation"},
|
||||||
|
)
|
||||||
|
generation_prompt_ratio: float | None = Field(
|
||||||
|
default=0.5,
|
||||||
|
json_schema_extra={"description": "Ratio of input to use as prompt (0.0-1.0)"},
|
||||||
|
)
|
||||||
|
generation_do_sample: bool | None = Field(
|
||||||
|
default=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to use sampling (vs greedy decoding)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@field_serializer("datasets")
|
@field_serializer("datasets")
|
||||||
def datasets_serializer(
|
def datasets_serializer(
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ class DeprecatedParameters(BaseModel):
|
|||||||
noisy_embedding_alpha: float | None = None
|
noisy_embedding_alpha: float | None = None
|
||||||
dpo_beta: float | None = None
|
dpo_beta: float | None = None
|
||||||
evaluation_strategy: str | None = None
|
evaluation_strategy: str | None = None
|
||||||
|
eval_table_size: int | None = None
|
||||||
|
eval_max_new_tokens: int | None = None
|
||||||
|
|
||||||
@field_validator("max_packed_sequence_len")
|
@field_validator("max_packed_sequence_len")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -55,6 +57,27 @@ class DeprecatedParameters(BaseModel):
|
|||||||
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
||||||
return evaluation_strategy
|
return evaluation_strategy
|
||||||
|
|
||||||
|
@field_validator("eval_table_size")
|
||||||
|
@classmethod
|
||||||
|
def validate_eval_table_size(cls, eval_table_size):
    """Log a deprecation warning when ``eval_table_size`` is set.

    The value itself is passed through unchanged.
    """
    if eval_table_size is None:
        return eval_table_size

    deprecation_notice = (
        "eval_table_size is deprecated and superseded by generate_samples config. "
        "Please use generate_samples: true and num_generation_samples instead. "
        "The LogPredictionCallback is replaced by the new sample generation feature."
    )
    LOG.warning(deprecation_notice)
    return eval_table_size
|
||||||
|
|
||||||
|
@field_validator("eval_max_new_tokens")
|
||||||
|
@classmethod
|
||||||
|
def validate_eval_max_new_tokens(cls, eval_max_new_tokens):
    """Log a deprecation warning when ``eval_max_new_tokens`` is set.

    The value itself is passed through unchanged.
    """
    if eval_max_new_tokens is None:
        return eval_max_new_tokens

    deprecation_notice = (
        "eval_max_new_tokens is deprecated and superseded by generate_samples config. "
        "Please use generation_max_new_tokens instead."
    )
    LOG.warning(deprecation_notice)
    return eval_max_new_tokens
|
||||||
|
|
||||||
|
|
||||||
class RemappedParameters(BaseModel):
|
class RemappedParameters(BaseModel):
|
||||||
"""Parameters that have been remapped to other names"""
|
"""Parameters that have been remapped to other names"""
|
||||||
|
|||||||
Reference in New Issue
Block a user