remove unused
This commit is contained in:
@@ -1456,21 +1456,6 @@ class StreamingValidationMixin:
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def check_streaming_skip_prepare_dataset(self):
|
|
||||||
"""Ensure skip_prepare_dataset is set for streaming datasets."""
|
|
||||||
# Check if streaming is enabled for training datasets
|
|
||||||
if self._is_streaming_enabled():
|
|
||||||
skip_prepare = getattr(self, "skip_prepare_dataset", None)
|
|
||||||
if skip_prepare is False:
|
|
||||||
LOG.warning(
|
|
||||||
"skip_prepare_dataset=False is not compatible with streaming "
|
|
||||||
"datasets. Setting skip_prepare_dataset=True."
|
|
||||||
)
|
|
||||||
self.skip_prepare_dataset = True
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_dataset_mixing_weights(self):
|
def check_dataset_mixing_weights(self):
|
||||||
"""Validate dataset mixing weights configuration."""
|
"""Validate dataset mixing weights configuration."""
|
||||||
|
|||||||
@@ -85,84 +85,6 @@ class TestStreamingDatasets:
|
|||||||
"Train Loss (%s) is too high",
|
"Train Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_streaming_eval_specific_mixing(self, temp_dir):
|
|
||||||
"""Test eval-specific mixing strategy override"""
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"flash_attention": True,
|
|
||||||
"sequence_len": 512,
|
|
||||||
"sample_packing": False,
|
|
||||||
"dataset_processes": 1,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "tatsu-lab/alpaca",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"test_datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
"split": "train",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "tatsu-lab/alpaca",
|
|
||||||
"type": "alpaca",
|
|
||||||
"split": "train",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
# Streaming config
|
|
||||||
"streaming": True,
|
|
||||||
"eval_streaming": True,
|
|
||||||
"max_steps": 3,
|
|
||||||
# Different mixing for train vs eval
|
|
||||||
"dataset_mixing_strategy": "round_robin",
|
|
||||||
"eval_dataset_mixing_strategy": "weighted",
|
|
||||||
"eval_mixing_weights": [0.6, 0.4],
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"save_safetensors": True,
|
|
||||||
"bf16": "auto",
|
|
||||||
"use_tensorboard": True,
|
|
||||||
"save_first_step": False,
|
|
||||||
"eval_steps": 3, # Eval at the end
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
|
|
||||||
# Check both train and eval losses
|
|
||||||
check_tensorboard(
|
|
||||||
temp_dir + "/runs",
|
|
||||||
"train/train_loss",
|
|
||||||
2.5,
|
|
||||||
"Train Loss (%s) is too high",
|
|
||||||
)
|
|
||||||
check_tensorboard(
|
|
||||||
temp_dir + "/runs",
|
|
||||||
"eval/eval_loss",
|
|
||||||
2.5,
|
|
||||||
"Eval Loss (%s) is too high",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_streaming_validation_error(self, temp_dir):
|
def test_streaming_validation_error(self, temp_dir):
|
||||||
"""Test that pydantic validation catches invalid streaming configs"""
|
"""Test that pydantic validation catches invalid streaming configs"""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user