Feat: Add warmup_ratio (#893)
* Feat: Add warmup_ratio * fix: update readme with more details on conflict
This commit is contained in:
@@ -675,7 +675,8 @@ gradient_accumulation_steps: 1
|
|||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
eval_batch_size:
|
eval_batch_size:
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
warmup_steps: 100
|
warmup_steps: 100 # cannot use with warmup_ratio
|
||||||
|
warmup_ratio: 0.05 # cannot use with warmup_steps
|
||||||
learning_rate: 0.00003
|
learning_rate: 0.00003
|
||||||
lr_quadratic_warmup:
|
lr_quadratic_warmup:
|
||||||
logging_steps:
|
logging_steps:
|
||||||
|
|||||||
@@ -461,11 +461,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
warmup_steps = (
|
warmup_steps = None
|
||||||
self.cfg.warmup_steps
|
if self.cfg.warmup_steps is not None:
|
||||||
if self.cfg.warmup_steps is not None
|
warmup_steps = self.cfg.warmup_steps
|
||||||
else min(int(0.03 * total_num_steps), 100)
|
elif self.cfg.warmup_ratio is not None:
|
||||||
)
|
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||||
|
else:
|
||||||
|
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||||
|
|
||||||
logging_steps = (
|
logging_steps = (
|
||||||
self.cfg.logging_steps
|
self.cfg.logging_steps
|
||||||
if self.cfg.logging_steps is not None
|
if self.cfg.logging_steps is not None
|
||||||
|
|||||||
@@ -372,6 +372,9 @@ def validate_config(cfg):
|
|||||||
if cfg.rope_scaling:
|
if cfg.rope_scaling:
|
||||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||||
|
|
||||||
|
if cfg.warmup_steps and cfg.warmup_ratio:
|
||||||
|
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -649,3 +649,33 @@ class ValidationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_warmup_step_no_conflict(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"warmup_steps": 10,
|
||||||
|
"warmup_ratio": 0.1,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"warmup_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"warmup_ratio": 0.1,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user