Feat: Add warmup_ratio (#893)
* Feat: Add warmup_ratio * fix: update readme with more details on conflict
This commit is contained in:
@@ -461,11 +461,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
warmup_steps = (
|
||||
self.cfg.warmup_steps
|
||||
if self.cfg.warmup_steps is not None
|
||||
else min(int(0.03 * total_num_steps), 100)
|
||||
)
|
||||
warmup_steps = None
|
||||
if self.cfg.warmup_steps is not None:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio is not None:
|
||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||
else:
|
||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||
|
||||
logging_steps = (
|
||||
self.cfg.logging_steps
|
||||
if self.cfg.logging_steps is not None
|
||||
|
||||
@@ -372,6 +372,9 @@ def validate_config(cfg):
|
||||
if cfg.rope_scaling:
|
||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||
|
||||
if cfg.warmup_steps and cfg.warmup_ratio:
|
||||
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
|
||||
Reference in New Issue
Block a user