From 17a5838d38aa61dd1fae899f1d03749a8684a0d9 Mon Sep 17 00:00:00 2001
From: mhenrhcsen
Date: Mon, 12 May 2025 14:36:43 +0200
Subject: [PATCH] lint

---
 docs/config.qmd                     |  2 +-
 src/axolotl/utils/data/rl.py        | 50 ++++++++++++++++----------
 src/axolotl/utils/data/utils.py     |  2 +-
 src/axolotl/utils/schemas/config.py |  4 ++-
 src/axolotl/utils/trainer.py        | 56 +++++++++++++++++------------
 5 files changed, 71 insertions(+), 43 deletions(-)

diff --git a/docs/config.qmd b/docs/config.qmd
index 227797fec..b12d36cf9 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -332,7 +332,7 @@ dataset_shard_idx:
 # The maximum length of an input to train with, this should typically be less than 2048
 # as most models have a token/context limit of 2048
 sequence_len: 2048
-# How to handle tokens exceeding max sequence length - "drop" (default, removes sample) or "truncate" (cuts off excess tokens) 
+# How to handle tokens exceeding max sequence length - "drop" (default, removes sample) or "truncate" (cuts off excess tokens)
 excess_token_handling: drop
 # Pad inputs so each step uses constant sized buffers
 # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index 052dff43a..84a47e85b 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -106,28 +106,36 @@ def drop_long_rl_seq(
             len_prompt + len_rejected
         ) <= sequence_len:
             return sample
-
+
         # For truncation, we need to truncate the chosen and rejected responses
         # to fit within sequence_len, but preserve the prompt
-
+
         # Calculate maximum response length that can fit with the prompt
         max_response_len = sequence_len - len_prompt
-
+
         if max_response_len <= 0:
             # Prompt is already too long, we can't truncate effectively
             return False if handling == "drop" else sample
-
+
         # Truncate the chosen and rejected responses if needed
         if len_chosen > max_response_len:
             # Tokenize, truncate, and decode
-            chosen_tokens = tokenizer(chosen, add_special_tokens=False)["input_ids"][:max_response_len]
-            sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True)
-
+            chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
+                "input_ids"
+            ][:max_response_len]
+            sample["chosen"] = tokenizer.decode(
+                chosen_tokens, skip_special_tokens=True
+            )
+
         if len_rejected > max_response_len:
             # Tokenize, truncate, and decode
-            rejected_tokens = tokenizer(rejected, add_special_tokens=False)["input_ids"][:max_response_len]
-            sample["rejected"] = tokenizer.decode(rejected_tokens, skip_special_tokens=True)
-
+            rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
+                "input_ids"
+            ][:max_response_len]
+            sample["rejected"] = tokenizer.decode(
+                rejected_tokens, skip_special_tokens=True
+            )
+
         return sample
 
     if rl == "kto":
@@ -148,20 +156,24 @@ def drop_long_rl_seq(
         # If sequence fits, return sample unchanged
         if (len_prompt + len_completion) <= sequence_len:
             return sample
-
+
         # Calculate maximum completion length that can fit with the prompt
         max_completion_len = sequence_len - len_prompt
-
+
        if max_completion_len <= 0:
             # Prompt is already too long, we can't truncate effectively
             return False if handling == "drop" else sample
-
+
         # Truncate the completion if needed
         if len_completion > max_completion_len:
             # Tokenize, truncate, and decode
-            completion_tokens = tokenizer(completion, add_special_tokens=False)["input_ids"][:max_completion_len]
-            sample["completion"] = tokenizer.decode(completion_tokens, skip_special_tokens=True)
-
+            completion_tokens = tokenizer(completion, add_special_tokens=False)[
+                "input_ids"
+            ][:max_completion_len]
+            sample["completion"] = tokenizer.decode(
+                completion_tokens, skip_special_tokens=True
+            )
+
         return sample
 
     if rl == "grpo":
@@ -223,7 +235,7 @@ def load_prepare_preference_datasets(cfg):
             )
 
             prior_len = len(split_datasets[i])
-
+
             # Use filter for drop mode and map for truncate mode
             handling = cfg.get("excess_token_handling", "drop")
             if handling == "drop":
@@ -245,7 +257,9 @@ def load_prepare_preference_datasets(cfg):
                     load_from_cache_file=not cfg.is_preprocess,
                     desc="Truncating Long Sequences",
                 )
-                LOG.info(f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens")
+                LOG.info(
+                    f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens"
+                )
 
     combined_datasets = concatenate_datasets(split_datasets)
     combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py
index 8a07af51e..b83e92b47 100644
--- a/src/axolotl/utils/data/utils.py
+++ b/src/axolotl/utils/data/utils.py
@@ -167,7 +167,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
 
     # Get the handling method from config, default to "drop" for backward compatibility
     handling = cfg.get("excess_token_handling", "drop")
-
+
     if handling == "drop":
         # Use the existing drop_long_seq function for backward compatibility
         seq_handler = functools.partial(
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 843a659c9..0c14761be 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -188,7 +188,9 @@ class AxolotlInputConfig(
     sequence_len: int = Field(default=512)
     excess_token_handling: Literal["drop", "truncate"] = Field(
         default="drop",
-        json_schema_extra={"description": "how to handle tokens exceeding max sequence length - drop the sample or truncate"},
+        json_schema_extra={
+            "description": "how to handle tokens exceeding max sequence length - drop the sample or truncate"
+        },
     )
     min_sample_len: int | None = None
     max_prompt_len: int = Field(
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 9e1576663..f9c134e82 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -235,7 +235,9 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
     return results
 
 
-def truncate_or_drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, handling="drop"):
+def truncate_or_drop_long_seq(
+    sample, sequence_len=2048, min_sequence_len=2, handling="drop"
+):
     """
     Either drop or truncate samples whose sequence length is either too long
     (> sequence_len) or too short (< min_sequence_len).
@@ -264,35 +266,37 @@ def truncate_or_drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, han
     if isinstance(input_ids[0], int):
         # Single example (input_ids is a list of int)
         length = len(input_ids)
-
+
         # Handle samples that are too short - always drop them
         if length < min_sequence_len:
             return False if handling == "drop" else sample
-
+
         # If truncation is enabled and the sample is too long, truncate it
        if length > sequence_len and handling == "truncate":
             sample["input_ids"] = input_ids[:sequence_len]
-
+
             # Also truncate attention_mask if present
             if "attention_mask" in sample:
                 sample["attention_mask"] = sample["attention_mask"][:sequence_len]
-
+
             # Also truncate labels if present
             if "labels" in sample:
                 sample["labels"] = sample["labels"][:sequence_len]
-
+
             # Also truncate position_ids if present
             if "position_ids" in sample:
                 sample["position_ids"] = sample["position_ids"][:sequence_len]
-
+
             # Update length if present
             if "length" in sample:
                 sample["length"] = sequence_len
-
+
             return sample
-
+
         # For drop mode or if the sample doesn't exceed max length
-        return min_sequence_len <= length <= sequence_len if handling == "drop" else sample
+        return (
+            min_sequence_len <= length <= sequence_len if handling == "drop" else sample
+        )
 
     # Batched (input_ids is a list of lists)
     if handling == "drop":
@@ -305,31 +309,33 @@ def truncate_or_drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, han
     # Check each sequence in the batch
     for i, seq in enumerate(input_ids):
         length = len(seq)
-
+
         # Skip sequences that are too short
         if length < min_sequence_len:
             continue
-
+
         # Truncate sequences that are too long
         if length > sequence_len:
             input_ids[i] = seq[:sequence_len]
-
+
             # Also truncate attention_mask if present
             if "attention_mask" in sample:
-                sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
-
+                sample["attention_mask"][i] = sample["attention_mask"][i][
+                    :sequence_len
+                ]
+
             # Also truncate labels if present
             if "labels" in sample:
                 sample["labels"][i] = sample["labels"][i][:sequence_len]
-
+
             # Also truncate position_ids if present
             if "position_ids" in sample:
                 sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
-
+
             # Update length if present
             if "length" in sample:
                 sample["length"][i] = sequence_len
-
+
     return sample
 
 
@@ -468,10 +474,14 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
 
 
 def process_pretraining_datasets_for_packing(
-    train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False, handling="drop"
+    train_dataset,
+    sequence_len,
+    skip_position_ids=True,
+    drop_attention_mask=False,
+    handling="drop",
 ):
     drop_long_fn = partial(drop_long_seq, sequence_len=sequence_len)
-
+
     # Use filter for drop mode and map for truncate mode
     if handling == "drop":
         train_dataset = train_dataset.filter(
@@ -480,13 +490,15 @@ def process_pretraining_datasets_for_packing(
             load_from_cache_file=False,
         )
     else:
-        truncate_fn = partial(truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling)
+        truncate_fn = partial(
+            truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling
+        )
         train_dataset = train_dataset.map(
             truncate_fn,
             desc="Truncating Long Sequences",
             load_from_cache_file=False,
         )
-
+
     if not skip_position_ids:
         train_dataset = train_dataset.map(
            add_position_ids,
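
Note (not part of the patch): a minimal usage sketch of the behavior this change introduces, assuming the patched tree is importable and that truncate_or_drop_long_seq reads sample["input_ids"] the same way the existing drop_long_seq does (the start of the function is outside the hunks above). The toy sample dict and lengths below are hypothetical.

    # Minimal sketch; the sample dict here is hypothetical toy data.
    from axolotl.utils.trainer import truncate_or_drop_long_seq

    sample = {
        "input_ids": list(range(3000)),  # 3000 tokens vs. sequence_len=2048
        "attention_mask": [1] * 3000,
        "labels": list(range(3000)),
    }

    # handling="drop" keeps the legacy behavior: the over-long sample is
    # rejected (the function returns False, so dataset.filter removes it).
    kept = truncate_or_drop_long_seq(dict(sample), sequence_len=2048, handling="drop")
    assert kept is False

    # handling="truncate" cuts input_ids, attention_mask, and labels down to
    # sequence_len and returns the sample (used with dataset.map instead).
    out = truncate_or_drop_long_seq(dict(sample), sequence_len=2048, handling="truncate")
    assert len(out["input_ids"]) == len(out["attention_mask"]) == 2048

In config terms this is excess_token_handling: truncate alongside sequence_len (see the docs/config.qmd hunk above); the default remains drop, which preserves the old drop-only behavior.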