Remove transscribe.py file and clean up optimizer.py and rl.py for improved formatting and consistency.
This commit is contained in:
@@ -185,12 +185,12 @@ class OptimizerMixin(Trainer):
|
|||||||
p.data_ptr(): p.numel() for p in module.parameters()
|
p.data_ptr(): p.numel() for p in module.parameters()
|
||||||
}.values()
|
}.values()
|
||||||
)
|
)
|
||||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
|
||||||
manager.register_module_override(
|
manager.register_module_override(
|
||||||
module, "weight", {"optim_bits": 32}
|
module, "weight", {"optim_bits": 32}
|
||||||
)
|
)
|
||||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
LOG.info(f"skipped: {skipped / 2 ** 20}M params")
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Literal, List, Union
|
from typing import Any, Callable, Literal
|
||||||
|
|
||||||
from datasets import Dataset, DatasetDict
|
from datasets import Dataset, DatasetDict
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
@@ -120,6 +120,8 @@ def _map_dataset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
|
def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
|
||||||
"""
|
"""
|
||||||
Backward-compatibility wrapper for legacy imports in tests.
|
Backward-compatibility wrapper for legacy imports in tests.
|
||||||
@@ -128,7 +130,6 @@ def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
|
|||||||
return _drop_long_sequences(sample, rl, tokenizer, sequence_len)
|
return _drop_long_sequences(sample, rl, tokenizer, sequence_len)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _drop_long_sequences(
|
def _drop_long_sequences(
|
||||||
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -163,8 +164,8 @@ def _drop_long_sequences(
|
|||||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||||
|
|
||||||
# Truncate first, then drop if still invalid (although truncate should handle it)
|
# Truncate first, then drop if still invalid (although truncate should handle it)
|
||||||
handling = sample.get("sequence_len_overflow_handling", "drop")
|
handling_mode = sample.get("sequence_len_overflow_handling", "drop")
|
||||||
if handling == "truncate":
|
if handling_mode == "truncate":
|
||||||
# If both sequences fit, return sample unchanged
|
# If both sequences fit, return sample unchanged
|
||||||
if (len_prompt + len_chosen) <= sequence_len and (
|
if (len_prompt + len_chosen) <= sequence_len and (
|
||||||
len_prompt + len_rejected
|
len_prompt + len_rejected
|
||||||
@@ -176,13 +177,12 @@ def _drop_long_sequences(
|
|||||||
|
|
||||||
if max_response_len <= 0:
|
if max_response_len <= 0:
|
||||||
# Prompt itself exceeds sequence length. Cannot truncate responses to fix it.
|
# Prompt itself exceeds sequence length. Cannot truncate responses to fix it.
|
||||||
# Keep sample shape for map(), but log a warning. A subsequent filter will drop it.
|
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; will be dropped post-truncation",
|
"Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; dropping",
|
||||||
len_prompt,
|
len_prompt,
|
||||||
sequence_len,
|
sequence_len,
|
||||||
)
|
)
|
||||||
result = sample
|
result = False
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Truncate the chosen and rejected responses if needed
|
# Truncate the chosen and rejected responses if needed
|
||||||
@@ -220,8 +220,8 @@ def _drop_long_sequences(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Truncate first
|
# Truncate first
|
||||||
handling = sample.get("sequence_len_overflow_handling", "drop")
|
handling_mode = sample.get("sequence_len_overflow_handling", "drop")
|
||||||
if handling == "truncate":
|
if handling_mode == "truncate":
|
||||||
# If sequence fits, return sample unchanged
|
# If sequence fits, return sample unchanged
|
||||||
if (len_prompt + len_completion) <= sequence_len:
|
if (len_prompt + len_completion) <= sequence_len:
|
||||||
result = sample
|
result = sample
|
||||||
@@ -232,11 +232,11 @@ def _drop_long_sequences(
|
|||||||
if max_completion_len <= 0:
|
if max_completion_len <= 0:
|
||||||
# Prompt itself exceeds sequence length. Cannot truncate completion to fix it.
|
# Prompt itself exceeds sequence length. Cannot truncate completion to fix it.
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Prompt length (%s) exceeds sequence length (%s) for KTO sample; will be dropped post-truncation",
|
"Prompt length (%s) exceeds sequence length (%s) for KTO sample; dropping",
|
||||||
len_prompt,
|
len_prompt,
|
||||||
sequence_len,
|
sequence_len,
|
||||||
)
|
)
|
||||||
result = sample
|
result = False
|
||||||
else:
|
else:
|
||||||
# Truncate the completion if needed
|
# Truncate the completion if needed
|
||||||
if len_completion > max_completion_len:
|
if len_completion > max_completion_len:
|
||||||
@@ -256,7 +256,7 @@ def _drop_long_sequences(
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Unknown RL type")
|
raise ValueError("Unknown RL type")
|
||||||
|
|
||||||
return result
|
return bool(result)
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_preference_datasets(cfg):
|
def load_prepare_preference_datasets(cfg):
|
||||||
@@ -296,14 +296,9 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def load_split(dataset_cfgs, _cfg):
|
# Legacy shim preserved for backward compatibility; no-op in new flow
|
||||||
split_datasets: List[Any] = []
|
def load_split(dataset_cfgs, _cfg): # noqa: F811
|
||||||
use_auth_token = _cfg.hf_use_auth_token
|
return None
|
||||||
for config_dataset in datasets_w_name_generator(dataset_cfgs):
|
|
||||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
|
||||||
config_dataset, use_auth_token, streaming=False
|
|
||||||
)
|
|
||||||
split_datasets.append(ds)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
import pymongo
|
|
||||||
|
|
||||||
MONGO_URI = "mongodb://root:9AsYmXYKmYLHcNsShmCb3L5DZMXH77rQ9GBRxm0HKownNWLwdzH9dW7zhPG9mpuR@46.4.101.229:8281/?directConnection=true"
|
|
||||||
COLLECTION_NAME = "tts_data"
|
|
||||||
|
|
||||||
client = pymongo.MongoClient(MONGO_URI)
|
|
||||||
db = client["tts_data"]
|
|
||||||
collection = db[COLLECTION_NAME]
|
|
||||||
|
|
||||||
# Get all documents from the collection that does not have a "transcription" field
|
|
||||||
documents = collection.find({"transcription": {"$exists": False}})
|
|
||||||
|
|
||||||
for document in documents:
|
|
||||||
print(document)
|
|
||||||
break
|
|
||||||
Reference in New Issue
Block a user