more fixes 20240228 (#1342) [skip ci]
* add missing evals_per_epoch setting * more pydantic fixes * more fixes * move test from normalization to validation * increase eval size for sample packing tests
This commit is contained in:
@@ -13,7 +13,6 @@ from threading import Thread
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import gradio as gr
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
@@ -215,6 +214,8 @@ def do_inference_gradio(
|
|||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
cli_args: TrainerCliArgs,
|
cli_args: TrainerCliArgs,
|
||||||
):
|
):
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||||
|
|||||||
@@ -164,9 +164,6 @@ def normalize_config(cfg):
|
|||||||
]
|
]
|
||||||
) or cfg.is_qwen_derived_model
|
) or cfg.is_qwen_derived_model
|
||||||
|
|
||||||
if isinstance(cfg.learning_rate, str):
|
|
||||||
cfg.learning_rate = float(cfg.learning_rate)
|
|
||||||
|
|
||||||
if isinstance(cfg.pretraining_dataset, dict):
|
if isinstance(cfg.pretraining_dataset, dict):
|
||||||
cfg.pretraining_dataset = [cfg.pretraining_dataset]
|
cfg.pretraining_dataset = [cfg.pretraining_dataset]
|
||||||
|
|
||||||
|
|||||||
@@ -302,6 +302,13 @@ class HyperparametersConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
return batch_size
|
return batch_size
|
||||||
|
|
||||||
|
@field_validator("learning_rate")
|
||||||
|
@classmethod
|
||||||
|
def convert_learning_rate(cls, learning_rate):
|
||||||
|
if learning_rate and isinstance(learning_rate, str):
|
||||||
|
learning_rate = float(learning_rate)
|
||||||
|
return learning_rate
|
||||||
|
|
||||||
|
|
||||||
class ModelOutputConfig(BaseModel):
|
class ModelOutputConfig(BaseModel):
|
||||||
"""model save configuration subset"""
|
"""model save configuration subset"""
|
||||||
@@ -368,6 +375,7 @@ class AxolotlInputConfig(
|
|||||||
rl: Optional[RLType] = None
|
rl: Optional[RLType] = None
|
||||||
|
|
||||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
||||||
|
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
||||||
dataset_prepared_path: Optional[str] = None
|
dataset_prepared_path: Optional[str] = None
|
||||||
dataset_shard_num: Optional[int] = None
|
dataset_shard_num: Optional[int] = None
|
||||||
dataset_shard_idx: Optional[int] = None
|
dataset_shard_idx: Optional[int] = None
|
||||||
@@ -456,6 +464,7 @@ class AxolotlInputConfig(
|
|||||||
warmup_steps: Optional[int] = None
|
warmup_steps: Optional[int] = None
|
||||||
warmup_ratio: Optional[float] = None
|
warmup_ratio: Optional[float] = None
|
||||||
eval_steps: Optional[Union[int, float]] = None
|
eval_steps: Optional[Union[int, float]] = None
|
||||||
|
evals_per_epoch: Optional[Union[int]] = None
|
||||||
evaluation_strategy: Optional[str] = None
|
evaluation_strategy: Optional[str] = None
|
||||||
save_steps: Optional[Union[int, float]] = None
|
save_steps: Optional[Union[int, float]] = None
|
||||||
saves_per_epoch: Optional[int] = None
|
saves_per_epoch: Optional[int] = None
|
||||||
@@ -463,6 +472,7 @@ class AxolotlInputConfig(
|
|||||||
save_total_limit: Optional[int] = None
|
save_total_limit: Optional[int] = None
|
||||||
logging_steps: Optional[int] = None
|
logging_steps: Optional[int] = None
|
||||||
early_stopping_patience: Optional[int] = None
|
early_stopping_patience: Optional[int] = None
|
||||||
|
load_best_model_at_end: Optional[bool] = False
|
||||||
|
|
||||||
neftune_noise_alpha: Optional[float] = None
|
neftune_noise_alpha: Optional[float] = None
|
||||||
|
|
||||||
|
|||||||
@@ -255,7 +255,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
train_dataset.remove_columns(["length"]),
|
train_dataset.remove_columns(["length"]),
|
||||||
batch_sampler=sampler,
|
batch_sampler=sampler,
|
||||||
)
|
)
|
||||||
data_loader_len = len(data_loader) // batch_size
|
data_loader_len = len(data_loader) // cfg.batch_size
|
||||||
actual_eff = sampler.efficiency()
|
actual_eff = sampler.efficiency()
|
||||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"lora_alpha": 64,
|
"lora_alpha": 64,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.2,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"unk_token": "<unk>",
|
||||||
"bos_token": "<s>",
|
"bos_token": "<s>",
|
||||||
|
|||||||
@@ -25,20 +25,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_lr_as_float(self):
|
|
||||||
cfg = (
|
|
||||||
self._get_base_cfg()
|
|
||||||
| DictDefault( # pylint: disable=unsupported-binary-operation
|
|
||||||
{
|
|
||||||
"learning_rate": "5e-5",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
normalize_config(cfg)
|
|
||||||
|
|
||||||
assert cfg.learning_rate == 0.00005
|
|
||||||
|
|
||||||
def test_base_model_config_set_when_empty(self):
|
def test_base_model_config_set_when_empty(self):
|
||||||
cfg = self._get_base_cfg()
|
cfg = self._get_base_cfg()
|
||||||
del cfg.base_model_config
|
del cfg.base_model_config
|
||||||
|
|||||||
@@ -176,6 +176,20 @@ class TestValidation(BaseValidation):
|
|||||||
with pytest.raises(ValueError, match=r".*At least two of*"):
|
with pytest.raises(ValueError, match=r".*At least two of*"):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_lr_as_float(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
|
{
|
||||||
|
"learning_rate": "5e-5",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
|
||||||
|
assert new_cfg.learning_rate == 0.00005
|
||||||
|
|
||||||
def test_qlora(self, minimal_cfg):
|
def test_qlora(self, minimal_cfg):
|
||||||
base_cfg = (
|
base_cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
|
|||||||
Reference in New Issue
Block a user