Remove transscribe.py (it contained a hardcoded MongoDB connection string with credentials; rotate that credential) and clean up optimizer.py and rl.py for improved formatting and consistency.

This commit is contained in:
mhenrhcsen
2025-08-12 21:20:48 +02:00
parent dc5887c652
commit f1a8474400
3 changed files with 17 additions and 37 deletions

View File

@@ -185,12 +185,12 @@ class OptimizerMixin(Trainer):
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
LOG.info(f"skipped: {skipped / 2 ** 20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init

View File

@@ -2,7 +2,7 @@
import inspect
from functools import partial
from typing import Any, Callable, Literal, List, Union
from typing import Any, Callable, Literal
from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer
@@ -120,6 +120,8 @@ def _map_dataset(
)
return dataset
def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
"""
Backward-compatibility wrapper for legacy imports in tests.
@@ -128,7 +130,6 @@ def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
return _drop_long_sequences(sample, rl, tokenizer, sequence_len)
def _drop_long_sequences(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool:
@@ -163,8 +164,8 @@ def _drop_long_sequences(
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
# Truncate first, then drop if still invalid (although truncate should handle it)
handling = sample.get("sequence_len_overflow_handling", "drop")
if handling == "truncate":
handling_mode = sample.get("sequence_len_overflow_handling", "drop")
if handling_mode == "truncate":
# If both sequences fit, return sample unchanged
if (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
@@ -176,13 +177,12 @@ def _drop_long_sequences(
if max_response_len <= 0:
# Prompt itself exceeds sequence length. Cannot truncate responses to fix it.
# Keep sample shape for map(), but log a warning. A subsequent filter will drop it.
LOG.warning(
"Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; will be dropped post-truncation",
"Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; dropping",
len_prompt,
sequence_len,
)
result = sample
result = False
else:
# Truncate the chosen and rejected responses if needed
@@ -220,8 +220,8 @@ def _drop_long_sequences(
)
# Truncate first
handling = sample.get("sequence_len_overflow_handling", "drop")
if handling == "truncate":
handling_mode = sample.get("sequence_len_overflow_handling", "drop")
if handling_mode == "truncate":
# If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len:
result = sample
@@ -232,11 +232,11 @@ def _drop_long_sequences(
if max_completion_len <= 0:
# Prompt itself exceeds sequence length. Cannot truncate completion to fix it.
LOG.warning(
"Prompt length (%s) exceeds sequence length (%s) for KTO sample; will be dropped post-truncation",
"Prompt length (%s) exceeds sequence length (%s) for KTO sample; dropping",
len_prompt,
sequence_len,
)
result = sample
result = False
else:
# Truncate the completion if needed
if len_completion > max_completion_len:
@@ -256,7 +256,7 @@ def _drop_long_sequences(
else:
raise ValueError("Unknown RL type")
return result
return bool(result)
def load_prepare_preference_datasets(cfg):
@@ -296,14 +296,9 @@ def load_prepare_preference_datasets(cfg):
return True
return False
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token
for config_dataset in datasets_w_name_generator(dataset_cfgs):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token, streaming=False
)
split_datasets.append(ds)
# Legacy shim preserved for backward compatibility; no-op in new flow
def load_split(dataset_cfgs, _cfg): # noqa: F811
return None
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:

View File

@@ -1,15 +0,0 @@
import pymongo
MONGO_URI = "mongodb://root:9AsYmXYKmYLHcNsShmCb3L5DZMXH77rQ9GBRxm0HKownNWLwdzH9dW7zhPG9mpuR@46.4.101.229:8281/?directConnection=true"
COLLECTION_NAME = "tts_data"
client = pymongo.MongoClient(MONGO_URI)
db = client["tts_data"]
collection = db[COLLECTION_NAME]
# Get all documents from the collection that does not have a "transcription" field
documents = collection.find({"transcription": {"$exists": False}})
for document in documents:
print(document)
break