RL datasets: warn on and drop prompts that remain over-length after truncation; add a post-truncation filter; support the config key alias 'excess_token_handling'
src/axolotl/utils/callbacks/orph3.yml (new file, +57 lines)
@@ -0,0 +1,57 @@
+base_model: syvai/tts-v1-pretrained
+# Automatically upload checkpoint and final model to HF
+hub_model_id: syvai/tts-v0.3-finetuned
+
+plugins:
+  - axolotl.integrations.liger.LigerPlugin
+liger_rope: true
+liger_rms_norm: true
+liger_glu_activation: true
+liger_fused_linear_cross_entropy: true
+
+datasets:
+  - path: syvai/zac-coral-tts
+    type:
+  - path: syvai/zac-dk-voice-pro
+    type:
+  - path: syvai/zac-dk-voice-single-speaker
+    type:
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+eval_sample_packing: False
+output_dir: ./outputs/finetuned
+
+sequence_len: 8196
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project: orph
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 16
+micro_batch_size: 2
+num_epochs: 3
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+bf16: auto
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_steps: 3
+evals_per_epoch: 5
+saves_per_epoch: 5
+weight_decay: 0.05
+
+special_tokens:
+  pad_token: <custom_token_7>

@@ -261,8 +261,13 @@ def encode_packed_pretraining(
         # workaround by using the position id logic for now in trainer
         drop_attention_mask=multipack_attn,
         # pass through handling mode from config via ds_wrapper function
-        handling=getattr(ds_wrapper, "cfg", {}).get(
-            "sequence_len_overflow_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
+        handling=(
+            getattr(ds_wrapper, "cfg", {}).get(
+                "sequence_len_overflow_handling",
+                getattr(ds_wrapper, "cfg", {}).get(
+                    "excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
+                ),
+            )
         ),
     )
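
The fallback chain above (canonical key first, then the legacy key, then the default) is the same pattern repeated in two later hunks. A minimal standalone sketch of it, assuming cfg behaves like a plain dict; the helper name resolve_overflow_handling is hypothetical, not part of this commit:

# Hypothetical helper illustrating the fallback chain used in this commit.
DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"

def resolve_overflow_handling(cfg: dict) -> str:
    # The canonical key wins; the legacy alias is consulted only when it is absent.
    return cfg.get(
        "sequence_len_overflow_handling",
        cfg.get("excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING),
    )

assert resolve_overflow_handling({"sequence_len_overflow_handling": "truncate"}) == "truncate"
assert resolve_overflow_handling({"excess_token_handling": "truncate"}) == "truncate"
assert resolve_overflow_handling({}) == "drop"

One caveat of dict.get-based fallback: a canonical key explicitly set to null/None still shadows the alias, since .get only falls through on a missing key.
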
@@ -116,28 +116,14 @@ def drop_long_rl_seq(
         max_response_len = sequence_len - len_prompt

         if max_response_len <= 0:
-            # Prompt is already too long, behavior depends on handling
-            # If truncate is chosen, we technically can't truncate, but drop seems harsh.
-            # Returning the sample might be unexpected. Let's stick to the filter logic
-            # which would drop this in the `filter` step later if needed.
-            # For now, return sample to map, or False to filter.
-            # Let's simplify: truncate *should* result in a valid sample if possible.
-            # If prompt >= seq_len, truncate won't work. Filter will catch this later.
-            # So, if max_response_len <= 0, we pass it through for map, drop for filter.
-            # However, the filter/map logic is applied *after* this function.
-            # This function needs to return the *modified* sample for map, or bool for filter.
-
-            # Re-think: If handling==truncate, return the modified sample if possible.
-            # If prompt >= seq_len, modification is impossible. What should map return?
-            # Maybe return the original sample? But map expects *modified* sample.
-            # Let's stick to the original logic: if prompt is too long, return False for filter
-            # and original sample for map.
-
-            result = (
-                sample  # For map, let downstream handle it if still invalid?
-            )
-            # Or maybe return None/empty dict? Let's return sample for now.
-            # If handling was drop, filter would remove this.
+            # Prompt itself exceeds sequence length. Cannot truncate responses to fix it.
+            # Keep sample shape for map(), but log a warning. A subsequent filter will drop it.
+            LOG.warning(
+                "Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; will be dropped post-truncation",
+                len_prompt,
+                sequence_len,
+            )
+            result = sample

         else:
             # Truncate the chosen and rejected responses if needed

@@ -184,7 +170,12 @@ def drop_long_rl_seq(
         max_completion_len = sequence_len - len_prompt

         if max_completion_len <= 0:
-            # Prompt too long, return sample for map
+            # Prompt itself exceeds sequence length. Cannot truncate completion to fix it.
+            LOG.warning(
+                "Prompt length (%s) exceeds sequence length (%s) for KTO sample; will be dropped post-truncation",
+                len_prompt,
+                sequence_len,
+            )
             result = sample
         else:
             # Truncate the completion if needed
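
A note on the %s placeholders in these warnings: Python's logging module formats the message lazily, only when the record is actually emitted, so no string interpolation happens for samples logged below the active level. A minimal sketch with arbitrary example values:

import logging

logging.basicConfig(level=logging.WARNING)
LOG = logging.getLogger(__name__)

# Arguments are interpolated into the message only if the record is emitted.
LOG.warning("Prompt length (%s) exceeds sequence length (%s)", 9000, 8192)
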
@@ -211,6 +202,41 @@ def drop_long_rl_seq(


 def load_prepare_preference_datasets(cfg):
+    def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len):
+        """
+        Boolean predicate to check whether a preference-learning sample fits within sequence_len.
+        Used with dataset.filter() after truncation to drop unsalvageable samples.
+        """
+        if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
+            if not (
+                sample.get("prompt")
+                and sample.get("chosen")
+                and sample.get("rejected")
+            ):
+                return False
+            prompt = sample["prompt"]
+            chosen = sample["chosen"]
+            rejected = sample["rejected"]
+            len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
+            len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
+            len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
+            return (len_prompt + len_chosen) <= sequence_len and (
+                len_prompt + len_rejected
+            ) <= sequence_len
+        if rl == RLType.KTO:
+            if not (sample.get("prompt") and sample.get("completion")):
+                return False
+            prompt = sample["prompt"]
+            completion = sample["completion"]
+            len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
+            len_completion = len(
+                tokenizer(completion, add_special_tokens=False)["input_ids"]
+            )
+            return (len_prompt + len_completion) <= sequence_len
+        if rl == RLType.GRPO:
+            # GRPO does not enforce this check here
+            return True
+        return False
     def load_split(dataset_cfgs, _cfg):
         split_datasets: List[Any] = []
         use_auth_token = _cfg.hf_use_auth_token
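
To see the new predicate's behavior in isolation, here is a self-contained toy run of the DPO-like branch. The whitespace-splitting ToyTokenizer stands in for a real Hugging Face tokenizer, and all names below are illustrative, not part of the commit:

from functools import partial

class ToyTokenizer:
    # Stand-in for a HF tokenizer: "tokens" are whitespace-separated words.
    def __call__(self, text, add_special_tokens=False):
        return {"input_ids": text.split()}

def within_sequence_len(sample, tokenizer, sequence_len):
    # Mirrors the DPO/IPO/ORPO/SimPO length check above.
    def n_tokens(text):
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    len_prompt = n_tokens(sample["prompt"])
    return (
        len_prompt + n_tokens(sample["chosen"]) <= sequence_len
        and len_prompt + n_tokens(sample["rejected"]) <= sequence_len
    )

keep = partial(within_sequence_len, tokenizer=ToyTokenizer(), sequence_len=4)
print(keep({"prompt": "a b", "chosen": "c", "rejected": "d"}))        # True: 2 + 1 <= 4
print(keep({"prompt": "a b c d e", "chosen": "f", "rejected": "g"}))  # False: prompt alone is 5
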
@@ -255,7 +281,11 @@ def load_prepare_preference_datasets(cfg):

     if not cfg.skip_prepare_dataset:
         # Determine handling mode
-        handling = cfg.get("sequence_len_overflow_handling", "drop")
+        # Support legacy alias "excess_token_handling" for compatibility
+        handling = cfg.get(
+            "sequence_len_overflow_handling",
+            cfg.get("excess_token_handling", "drop"),
+        )

         drop_long = partial(
             drop_long_rl_seq,

@@ -275,7 +305,18 @@ def load_prepare_preference_datasets(cfg):
             num_proc=cfg.dataset_processes,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Truncating Long Sequences",
         )
-        # Note: Length might not change if truncation always occurs
+        # After truncation, drop any samples that still exceed sequence_len (e.g., prompt alone too long)
+        split_datasets[i] = split_datasets[i].filter(
+            partial(
+                _is_rl_seq_within_sequence_len,
+                rl=_cfg.rl,
+                tokenizer=tokenizer,
+                sequence_len=cfg.sequence_len,
+            ),
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+            desc="Dropping Oversize Samples After Truncation",
+        )
         LOG.info(
             f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
         )
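
The map-then-filter split here follows from the datasets API, as the removed comments in the earlier hunk already observed: Dataset.map() must return a sample for every row and cannot drop rows, so truncation happens in map() and unsalvageable rows are removed by the follow-up filter(). A character-level sketch of this two-phase pattern (characters stand in for tokens; requires the datasets package; all values illustrative):

from datasets import Dataset

ds = Dataset.from_dict({
    "prompt": ["hi", "this prompt is far too long"],
    "response": ["a very very long response", "ok"],
})
MAX_LEN = 12

# Phase 1 (map): truncate what can be truncated -- the response.
ds = ds.map(lambda s: {"response": s["response"][: max(MAX_LEN - len(s["prompt"]), 0)]})
# Phase 2 (filter): drop rows truncation could not salvage (prompt alone > MAX_LEN).
ds = ds.filter(lambda s: len(s["prompt"]) + len(s["response"]) <= MAX_LEN)

print(ds["prompt"])  # ['hi']
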
@@ -168,7 +168,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
         return dataset

     # Get the handling method from config, default to "drop" for backward compatibility
-    handling = cfg.get("sequence_len_overflow_handling", "drop")
+    # Support legacy alias "excess_token_handling" as well
+    handling = cfg.get(
+        "sequence_len_overflow_handling",
+        cfg.get("excess_token_handling", "drop"),
+    )

     # Use the new function with the specified handling mode
     seq_handler = functools.partial(
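
The hunk is cut off here, but the surrounding pattern matches the drop_long construction earlier: functools.partial pre-binds the configuration once, and the resulting per-sample callable is handed to map()/filter(). A generic sketch with a stand-in function (drop_or_truncate is illustrative, not axolotl's actual handler):

import functools

def drop_or_truncate(sample, handling, sequence_len):
    # Illustrative stand-in for per-sample length handling.
    if handling == "truncate":
        return {"text": sample["text"][:sequence_len]}
    return sample

seq_handler = functools.partial(drop_or_truncate, handling="truncate", sequence_len=8)
print(seq_handler({"text": "hello world"}))  # {'text': 'hello wo'}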