RL datasets: warn and drop unsalvageable over-length prompts post-truncate; add post-truncate filter; support alias config key 'excess_token_handling'

This commit is contained in:
mhenrhcsen
2025-08-12 20:37:41 +02:00
parent 618b008e36
commit f5a3e3529e
8 changed files with 844 additions and 27 deletions

View File

@@ -0,0 +1,57 @@
base_model: syvai/tts-v1-pretrained
# Automatically upload checkpoint and final model to HF
hub_model_id: syvai/tts-v0.3-finetuned
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
datasets:
  - path: syvai/zac-coral-tts
    type:
  - path: syvai/zac-dk-voice-pro
    type:
  - path: syvai/zac-dk-voice-single-speaker
    type:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
eval_sample_packing: False
output_dir: ./outputs/finetuned
sequence_len: 8196
sample_packing: true
pad_to_sequence_len: true
wandb_project: orph
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 16
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 3
evals_per_epoch: 5
saves_per_epoch: 5
weight_decay: 0.05
special_tokens:
  pad_token: <custom_token_7>

View File

@@ -261,8 +261,13 @@ def encode_packed_pretraining(
        # workaround by using the position id logic for now in trainer
        drop_attention_mask=multipack_attn,
        # pass through handling mode from config via ds_wrapper function
-       handling=getattr(ds_wrapper, "cfg", {}).get(
-           "sequence_len_overflow_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
-       ),
+       handling=(
+           getattr(ds_wrapper, "cfg", {}).get(
+               "sequence_len_overflow_handling",
+               getattr(ds_wrapper, "cfg", {}).get(
+                   "excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
+               ),
+           )
+       ),
    )
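The nested lookup above can be hard to parse at a glance. The following toy sketch is not axolotl code (the wrapper classes, handling_for, and the "drop" default are illustrative assumptions); it only shows the precedence the diff implements: the new sequence_len_overflow_handling key wins, the legacy excess_token_handling alias is the fallback, and a wrapper without a cfg attribute falls back to the default thanks to getattr(..., "cfg", {}).

# Illustrative only: toy stand-ins for ds_wrapper, mirroring the lookup pattern above.
DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"  # assumed default value


class WrapperWithAlias:
    # only the legacy alias is set
    cfg = {"excess_token_handling": "truncate"}


class WrapperWithBoth:
    # the new key should win over the alias
    cfg = {"sequence_len_overflow_handling": "drop", "excess_token_handling": "truncate"}


class WrapperWithoutCfg:
    # no cfg attribute at all; getattr falls back to {}
    pass


def handling_for(ds_wrapper):
    return getattr(ds_wrapper, "cfg", {}).get(
        "sequence_len_overflow_handling",
        getattr(ds_wrapper, "cfg", {}).get(
            "excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
        ),
    )


print(handling_for(WrapperWithAlias()))   # -> "truncate" (legacy alias honoured)
print(handling_for(WrapperWithBoth()))    # -> "drop" (new key takes precedence)
print(handling_for(WrapperWithoutCfg()))  # -> "drop" (default)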

View File

@@ -116,28 +116,14 @@ def drop_long_rl_seq(
        max_response_len = sequence_len - len_prompt
        if max_response_len <= 0:
-           # Prompt is already too long, behavior depends on handling
-           # If truncate is chosen, we technically can't truncate, but drop seems harsh.
-           # Returning the sample might be unexpected. Let's stick to the filter logic
-           # which would drop this in the `filter` step later if needed.
-           # For now, return sample to map, or False to filter.
-           # Let's simplify: truncate *should* result in a valid sample if possible.
-           # If prompt >= seq_len, truncate won't work. Filter will catch this later.
-           # So, if max_response_len <= 0, we pass it through for map, drop for filter.
-           # However, the filter/map logic is applied *after* this function.
-           # This function needs to return the *modified* sample for map, or bool for filter.
-           # Re-think: If handling==truncate, return the modified sample if possible.
-           # If prompt >= seq_len, modification is impossible. What should map return?
-           # Maybe return the original sample? But map expects *modified* sample.
-           # Let's stick to the original logic: if prompt is too long, return False for filter
-           # and original sample for map.
-           result = (
-               sample  # For map, let downstream handle it if still invalid?
-               # Or maybe return None/empty dict? Let's return sample for now.
-               # If handling was drop, filter would remove this.
-           )
+           # Prompt itself exceeds sequence length. Cannot truncate responses to fix it.
+           # Keep sample shape for map(), but log a warning. A subsequent filter will drop it.
+           LOG.warning(
+               "Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; will be dropped post-truncation",
+               len_prompt,
+               sequence_len,
+           )
+           result = sample
        else:
            # Truncate the chosen and rejected responses if needed
@@ -184,7 +170,12 @@ def drop_long_rl_seq(
        max_completion_len = sequence_len - len_prompt
        if max_completion_len <= 0:
-           # Prompt too long, return sample for map
+           # Prompt itself exceeds sequence length. Cannot truncate completion to fix it.
+           LOG.warning(
+               "Prompt length (%s) exceeds sequence length (%s) for KTO sample; will be dropped post-truncation",
+               len_prompt,
+               sequence_len,
+           )
            result = sample
        else:
            # Truncate the completion if needed
@@ -211,6 +202,41 @@ def drop_long_rl_seq(
def load_prepare_preference_datasets(cfg):
+    def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len):
+        """
+        Boolean predicate to check whether a preference-learning sample fits within sequence_len.
+        Used with dataset.filter() after truncation to drop unsalvageable samples.
+        """
+        if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
+            if not (
+                sample.get("prompt")
+                and sample.get("chosen")
+                and sample.get("rejected")
+            ):
+                return False
+            prompt = sample["prompt"]
+            chosen = sample["chosen"]
+            rejected = sample["rejected"]
+            len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
+            len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
+            len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
+            return (len_prompt + len_chosen) <= sequence_len and (
+                len_prompt + len_rejected
+            ) <= sequence_len
+        if rl == RLType.KTO:
+            if not (sample.get("prompt") and sample.get("completion")):
+                return False
+            prompt = sample["prompt"]
+            completion = sample["completion"]
+            len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
+            len_completion = len(
+                tokenizer(completion, add_special_tokens=False)["input_ids"]
+            )
+            return (len_prompt + len_completion) <= sequence_len
+        if rl == RLType.GRPO:
+            # GRPO does not enforce this check here
+            return True
+        return False

    def load_split(dataset_cfgs, _cfg):
        split_datasets: List[Any] = []
        use_auth_token = _cfg.hf_use_auth_token
@@ -255,7 +281,11 @@ def load_prepare_preference_datasets(cfg):
    if not cfg.skip_prepare_dataset:
        # Determine handling mode
-       handling = cfg.get("sequence_len_overflow_handling", "drop")
+       # Support legacy alias "excess_token_handling" for compatibility
+       handling = cfg.get(
+           "sequence_len_overflow_handling",
+           cfg.get("excess_token_handling", "drop"),
+       )

        drop_long = partial(
            drop_long_rl_seq,
@@ -275,7 +305,18 @@ def load_prepare_preference_datasets(cfg):
                load_from_cache_file=not cfg.is_preprocess,
                desc="Truncating Long Sequences",
            )
-           # Note: Length might not change if truncation always occurs
+           # After truncation, drop any samples that still exceed sequence_len (e.g., prompt alone too long)
+           split_datasets[i] = split_datasets[i].filter(
+               partial(
+                   _is_rl_seq_within_sequence_len,
+                   rl=_cfg.rl,
+                   tokenizer=tokenizer,
+                   sequence_len=cfg.sequence_len,
+               ),
+               num_proc=cfg.dataset_processes,
+               load_from_cache_file=not cfg.is_preprocess,
+               desc="Dropping Oversize Samples After Truncation",
+           )
            LOG.info(
                f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
            )
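Taken together, the two passes above implement a truncate-then-drop flow. The sketch below is a heavily simplified illustration, not the axolotl implementation: it splits on whitespace instead of using the tokenizer, uses a single KTO-style prompt/completion pair instead of RLType dispatch, and hard-codes SEQ_LEN in place of cfg.sequence_len. It shows why the map step passes an unsalvageable sample through (map must return a sample dict) and how the subsequent filter then removes it.

# Toy sketch only: whitespace "tokens" and made-up samples; real code uses the HF tokenizer.
from datasets import Dataset

SEQ_LEN = 8  # stand-in for cfg.sequence_len


def truncate_or_passthrough(sample):
    # map() must return a sample dict, so an unsalvageable sample is passed through unchanged
    len_prompt = len(sample["prompt"].split())
    budget = SEQ_LEN - len_prompt
    if budget <= 0:
        return sample  # analogous to the LOG.warning branch above
    sample["completion"] = " ".join(sample["completion"].split()[:budget])
    return sample


def fits_within_sequence_len(sample):
    # analogous to _is_rl_seq_within_sequence_len: drop anything that still does not fit
    return len(sample["prompt"].split()) + len(sample["completion"].split()) <= SEQ_LEN


ds = Dataset.from_list(
    [
        {"prompt": "short prompt", "completion": "a very long completion that gets truncated"},
        {"prompt": "this prompt alone is already far too long to fit", "completion": "ok"},
    ]
)
ds = ds.map(truncate_or_passthrough)      # truncates sample 0, passes sample 1 through
ds = ds.filter(fits_within_sequence_len)  # drops sample 1 after truncation
print(len(ds))  # -> 1

Keeping the drop in its own filter pass also means num_proc and dataset caching apply to it the same way they do to the truncation map.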

View File

@@ -168,7 +168,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
        return dataset

    # Get the handling method from config, default to "drop" for backward compatibility
-   handling = cfg.get("sequence_len_overflow_handling", "drop")
+   # Support legacy alias "excess_token_handling" as well
+   handling = cfg.get(
+       "sequence_len_overflow_handling",
+       cfg.get("excess_token_handling", "drop"),
+   )

    # Use the new function with the specified handling mode
    seq_handler = functools.partial(