more fixes 20240228 (#1342) [skip ci]

* add missing evals_per_epoch setting

* more pydantic fixes

* more fixes

* move test from normalization to validation

* increase eval size for sample packing tests
This commit is contained in:
Wing Lian
2024-02-28 12:57:45 -05:00
committed by GitHub
parent c1a7b3dd69
commit 0f985e12fe
7 changed files with 28 additions and 20 deletions

View File

@@ -13,7 +13,6 @@ from threading import Thread
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import gradio as gr
import requests import requests
import torch import torch
import yaml import yaml
@@ -215,6 +214,8 @@ def do_inference_gradio(
cfg: DictDefault, cfg: DictDefault,
cli_args: TrainerCliArgs, cli_args: TrainerCliArgs,
): ):
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"} default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

View File

@@ -164,9 +164,6 @@ def normalize_config(cfg):
] ]
) or cfg.is_qwen_derived_model ) or cfg.is_qwen_derived_model
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)
if isinstance(cfg.pretraining_dataset, dict): if isinstance(cfg.pretraining_dataset, dict):
cfg.pretraining_dataset = [cfg.pretraining_dataset] cfg.pretraining_dataset = [cfg.pretraining_dataset]

View File

@@ -302,6 +302,13 @@ class HyperparametersConfig(BaseModel):
) )
return batch_size return batch_size
@field_validator("learning_rate")
@classmethod
def convert_learning_rate(cls, learning_rate):
if learning_rate and isinstance(learning_rate, str):
learning_rate = float(learning_rate)
return learning_rate
class ModelOutputConfig(BaseModel): class ModelOutputConfig(BaseModel):
"""model save configuration subset""" """model save configuration subset"""
@@ -368,6 +375,7 @@ class AxolotlInputConfig(
rl: Optional[RLType] = None rl: Optional[RLType] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
dataset_prepared_path: Optional[str] = None dataset_prepared_path: Optional[str] = None
dataset_shard_num: Optional[int] = None dataset_shard_num: Optional[int] = None
dataset_shard_idx: Optional[int] = None dataset_shard_idx: Optional[int] = None
@@ -456,6 +464,7 @@ class AxolotlInputConfig(
warmup_steps: Optional[int] = None warmup_steps: Optional[int] = None
warmup_ratio: Optional[float] = None warmup_ratio: Optional[float] = None
eval_steps: Optional[Union[int, float]] = None eval_steps: Optional[Union[int, float]] = None
evals_per_epoch: Optional[Union[int]] = None
evaluation_strategy: Optional[str] = None evaluation_strategy: Optional[str] = None
save_steps: Optional[Union[int, float]] = None save_steps: Optional[Union[int, float]] = None
saves_per_epoch: Optional[int] = None saves_per_epoch: Optional[int] = None
@@ -463,6 +472,7 @@ class AxolotlInputConfig(
save_total_limit: Optional[int] = None save_total_limit: Optional[int] = None
logging_steps: Optional[int] = None logging_steps: Optional[int] = None
early_stopping_patience: Optional[int] = None early_stopping_patience: Optional[int] = None
load_best_model_at_end: Optional[bool] = False
neftune_noise_alpha: Optional[float] = None neftune_noise_alpha: Optional[float] = None

View File

@@ -255,7 +255,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
train_dataset.remove_columns(["length"]), train_dataset.remove_columns(["length"]),
batch_sampler=sampler, batch_sampler=sampler,
) )
data_loader_len = len(data_loader) // batch_size data_loader_len = len(data_loader) // cfg.batch_size
actual_eff = sampler.efficiency() actual_eff = sampler.efficiency()
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True) LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends # FIXME: is there a bug here somewhere? the total num steps depends

View File

@@ -43,7 +43,7 @@ class TestLoraLlama(unittest.TestCase):
"lora_alpha": 64, "lora_alpha": 64,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.2,
"special_tokens": { "special_tokens": {
"unk_token": "<unk>", "unk_token": "<unk>",
"bos_token": "<s>", "bos_token": "<s>",

View File

@@ -25,20 +25,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
} }
) )
def test_lr_as_float(self):
cfg = (
self._get_base_cfg()
| DictDefault( # pylint: disable=unsupported-binary-operation
{
"learning_rate": "5e-5",
}
)
)
normalize_config(cfg)
assert cfg.learning_rate == 0.00005
def test_base_model_config_set_when_empty(self): def test_base_model_config_set_when_empty(self):
cfg = self._get_base_cfg() cfg = self._get_base_cfg()
del cfg.base_model_config del cfg.base_model_config

View File

@@ -176,6 +176,20 @@ class TestValidation(BaseValidation):
with pytest.raises(ValueError, match=r".*At least two of*"): with pytest.raises(ValueError, match=r".*At least two of*"):
validate_config(cfg) validate_config(cfg)
def test_lr_as_float(self, minimal_cfg):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"learning_rate": "5e-5",
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.learning_rate == 0.00005
def test_qlora(self, minimal_cfg): def test_qlora(self, minimal_cfg):
base_cfg = ( base_cfg = (
DictDefault( DictDefault(