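Remove transscribe.py file and clean up optimizer.py and rl.py for improved formatting and consistency. Note that the rl.py changes are not purely cosmetic: `_drop_long_sequences` now returns `bool(result)` and yields False when the prompt alone exceeds the sequence length in truncate mode, and the nested `load_split` helper is reduced to a no-op shim that returns None.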

This commit is contained in:
mhenrhcsen
2025-08-12 21:20:48 +02:00
parent dc5887c652
commit f1a8474400
3 changed files with 17 additions and 37 deletions

View File

@@ -185,12 +185,12 @@ class OptimizerMixin(Trainer):
p.data_ptr(): p.numel() for p in module.parameters() p.data_ptr(): p.numel() for p in module.parameters()
}.values() }.values()
) )
LOG.info(f"skipped {module}: {skipped/2**20}M params") LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
manager.register_module_override( manager.register_module_override(
module, "weight", {"optim_bits": 32} module, "weight", {"optim_bits": 32}
) )
LOG.debug(f"bitsandbytes: will optimize {module} in fp32") LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params") LOG.info(f"skipped: {skipped / 2 ** 20}M params")
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init

View File

@@ -2,7 +2,7 @@
import inspect import inspect
from functools import partial from functools import partial
from typing import Any, Callable, Literal, List, Union from typing import Any, Callable, Literal
from datasets import Dataset, DatasetDict from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@@ -120,6 +120,8 @@ def _map_dataset(
) )
return dataset return dataset
def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"): def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
""" """
Backward-compatibility wrapper for legacy imports in tests. Backward-compatibility wrapper for legacy imports in tests.
@@ -128,7 +130,6 @@ def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
return _drop_long_sequences(sample, rl, tokenizer, sequence_len) return _drop_long_sequences(sample, rl, tokenizer, sequence_len)
def _drop_long_sequences( def _drop_long_sequences(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool: ) -> bool:
@@ -163,8 +164,8 @@ def _drop_long_sequences(
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"]) len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
# Truncate first, then drop if still invalid (although truncate should handle it) # Truncate first, then drop if still invalid (although truncate should handle it)
handling = sample.get("sequence_len_overflow_handling", "drop") handling_mode = sample.get("sequence_len_overflow_handling", "drop")
if handling == "truncate": if handling_mode == "truncate":
# If both sequences fit, return sample unchanged # If both sequences fit, return sample unchanged
if (len_prompt + len_chosen) <= sequence_len and ( if (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected len_prompt + len_rejected
@@ -176,13 +177,12 @@ def _drop_long_sequences(
if max_response_len <= 0: if max_response_len <= 0:
# Prompt itself exceeds sequence length. Cannot truncate responses to fix it. # Prompt itself exceeds sequence length. Cannot truncate responses to fix it.
# Keep sample shape for map(), but log a warning. A subsequent filter will drop it.
LOG.warning( LOG.warning(
"Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; will be dropped post-truncation", "Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; dropping",
len_prompt, len_prompt,
sequence_len, sequence_len,
) )
result = sample result = False
else: else:
# Truncate the chosen and rejected responses if needed # Truncate the chosen and rejected responses if needed
@@ -220,8 +220,8 @@ def _drop_long_sequences(
) )
# Truncate first # Truncate first
handling = sample.get("sequence_len_overflow_handling", "drop") handling_mode = sample.get("sequence_len_overflow_handling", "drop")
if handling == "truncate": if handling_mode == "truncate":
# If sequence fits, return sample unchanged # If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len: if (len_prompt + len_completion) <= sequence_len:
result = sample result = sample
@@ -232,11 +232,11 @@ def _drop_long_sequences(
if max_completion_len <= 0: if max_completion_len <= 0:
# Prompt itself exceeds sequence length. Cannot truncate completion to fix it. # Prompt itself exceeds sequence length. Cannot truncate completion to fix it.
LOG.warning( LOG.warning(
"Prompt length (%s) exceeds sequence length (%s) for KTO sample; will be dropped post-truncation", "Prompt length (%s) exceeds sequence length (%s) for KTO sample; dropping",
len_prompt, len_prompt,
sequence_len, sequence_len,
) )
result = sample result = False
else: else:
# Truncate the completion if needed # Truncate the completion if needed
if len_completion > max_completion_len: if len_completion > max_completion_len:
@@ -256,7 +256,7 @@ def _drop_long_sequences(
else: else:
raise ValueError("Unknown RL type") raise ValueError("Unknown RL type")
return result return bool(result)
def load_prepare_preference_datasets(cfg): def load_prepare_preference_datasets(cfg):
@@ -296,14 +296,9 @@ def load_prepare_preference_datasets(cfg):
return True return True
return False return False
def load_split(dataset_cfgs, _cfg): # Legacy shim preserved for backward compatibility; no-op in new flow
split_datasets: List[Any] = [] def load_split(dataset_cfgs, _cfg): # noqa: F811
use_auth_token = _cfg.hf_use_auth_token return None
for config_dataset in datasets_w_name_generator(dataset_cfgs):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token, streaming=False
)
split_datasets.append(ds)
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:

View File

@@ -1,15 +0,0 @@
"""Print the first ``tts_data`` document that has no ``transcription`` field.

The MongoDB connection string is read from the ``MONGO_URI`` environment
variable so that credentials are never stored in the repository.
"""
import os

import pymongo

# The connection string embeds the database password, so it must come from the
# environment rather than from source code.
# NOTE(review): any password that was previously committed alongside this
# script should be treated as leaked and rotated.
MONGO_URI = os.environ.get("MONGO_URI")
if not MONGO_URI:
    raise SystemExit(
        "Set the MONGO_URI environment variable to the MongoDB connection string"
    )

COLLECTION_NAME = "tts_data"

# MongoClient works as a context manager; leaving the block closes the
# connection instead of leaking it until interpreter exit.
with pymongo.MongoClient(MONGO_URI) as client:
    db = client["tts_data"]
    collection = db[COLLECTION_NAME]

    # Fetch only the first document that does not have a "transcription" field
    # (find_one returns None when every document already has one).
    document = collection.find_one({"transcription": {"$exists": False}})
    if document is not None:
        print(document)