From f1a847440088565c30552f0c4a2c4e96eb97d470 Mon Sep 17 00:00:00 2001 From: mhenrhcsen Date: Tue, 12 Aug 2025 21:20:48 +0200 Subject: [PATCH] Remove transscribe.py file and clean up optimizer.py and rl.py for improved formatting and consistency. --- src/axolotl/core/trainers/mixins/optimizer.py | 4 +-- src/axolotl/utils/data/rl.py | 35 ++++++++----------- transscribe.py | 15 -------- 3 files changed, 17 insertions(+), 37 deletions(-) delete mode 100644 transscribe.py diff --git a/src/axolotl/core/trainers/mixins/optimizer.py b/src/axolotl/core/trainers/mixins/optimizer.py index a9a9a3992..11ba9f524 100644 --- a/src/axolotl/core/trainers/mixins/optimizer.py +++ b/src/axolotl/core/trainers/mixins/optimizer.py @@ -185,12 +185,12 @@ class OptimizerMixin(Trainer): p.data_ptr(): p.numel() for p in module.parameters() }.values() ) - LOG.info(f"skipped {module}: {skipped/2**20}M params") + LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params") manager.register_module_override( module, "weight", {"optim_bits": 32} ) LOG.debug(f"bitsandbytes: will optimize {module} in fp32") - LOG.info(f"skipped: {skipped/2**20}M params") + LOG.info(f"skipped: {skipped / 2 ** 20}M params") if is_sagemaker_mp_enabled(): self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index f8a839b74..3feb5de49 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -2,7 +2,7 @@ import inspect from functools import partial -from typing import Any, Callable, Literal, List, Union +from typing import Any, Callable, Literal from datasets import Dataset, DatasetDict from transformers import PreTrainedTokenizer @@ -120,6 +120,8 @@ def _map_dataset( ) return dataset + + def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"): """ Backward-compatibility wrapper for legacy imports in tests. @@ -128,7 +130,6 @@ def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"): return _drop_long_sequences(sample, rl, tokenizer, sequence_len) - def _drop_long_sequences( sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int ) -> bool: @@ -163,8 +164,8 @@ def _drop_long_sequences( len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"]) # Truncate first, then drop if still invalid (although truncate should handle it) - handling = sample.get("sequence_len_overflow_handling", "drop") - if handling == "truncate": + handling_mode = sample.get("sequence_len_overflow_handling", "drop") + if handling_mode == "truncate": # If both sequences fit, return sample unchanged if (len_prompt + len_chosen) <= sequence_len and ( len_prompt + len_rejected @@ -176,13 +177,12 @@ def _drop_long_sequences( if max_response_len <= 0: # Prompt itself exceeds sequence length. Cannot truncate responses to fix it. - # Keep sample shape for map(), but log a warning. A subsequent filter will drop it. LOG.warning( - "Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; will be dropped post-truncation", + "Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; dropping", len_prompt, sequence_len, ) - result = sample + result = False else: # Truncate the chosen and rejected responses if needed @@ -220,8 +220,8 @@ def _drop_long_sequences( ) # Truncate first - handling = sample.get("sequence_len_overflow_handling", "drop") - if handling == "truncate": + handling_mode = sample.get("sequence_len_overflow_handling", "drop") + if handling_mode == "truncate": # If sequence fits, return sample unchanged if (len_prompt + len_completion) <= sequence_len: result = sample @@ -232,11 +232,11 @@ def _drop_long_sequences( if max_completion_len <= 0: # Prompt itself exceeds sequence length. Cannot truncate completion to fix it. LOG.warning( - "Prompt length (%s) exceeds sequence length (%s) for KTO sample; will be dropped post-truncation", + "Prompt length (%s) exceeds sequence length (%s) for KTO sample; dropping", len_prompt, sequence_len, ) - result = sample + result = False else: # Truncate the completion if needed if len_completion > max_completion_len: @@ -256,7 +256,7 @@ def _drop_long_sequences( else: raise ValueError("Unknown RL type") - return result + return bool(result) def load_prepare_preference_datasets(cfg): @@ -296,14 +296,9 @@ def load_prepare_preference_datasets(cfg): return True return False - def load_split(dataset_cfgs, _cfg): - split_datasets: List[Any] = [] - use_auth_token = _cfg.hf_use_auth_token - for config_dataset in datasets_w_name_generator(dataset_cfgs): - ds: Union[Dataset, DatasetDict] = load_dataset_w_config( - config_dataset, use_auth_token, streaming=False - ) - split_datasets.append(ds) + # Legacy shim preserved for backward compatibility; no-op in new flow + def load_split(dataset_cfgs, _cfg): # noqa: F811 + return None def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: diff --git a/transscribe.py b/transscribe.py deleted file mode 100644 index b5e16c984..000000000 --- a/transscribe.py +++ /dev/null @@ -1,15 +0,0 @@ -import pymongo - -MONGO_URI = "mongodb://root:9AsYmXYKmYLHcNsShmCb3L5DZMXH77rQ9GBRxm0HKownNWLwdzH9dW7zhPG9mpuR@46.4.101.229:8281/?directConnection=true" -COLLECTION_NAME = "tts_data" - -client = pymongo.MongoClient(MONGO_URI) -db = client["tts_data"] -collection = db[COLLECTION_NAME] - -# Get all documents from the collection that does not have a "transcription" field -documents = collection.find({"transcription": {"$exists": False}}) - -for document in documents: - print(document) - break