fix: improve handling of cosine hyp
This commit is contained in:
@@ -227,10 +227,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["lr_scheduler_kwargs"] = (
|
training_args_kwargs["lr_scheduler_kwargs"] = (
|
||||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||||
)
|
)
|
||||||
training_args_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
|
|
||||||
training_args_kwargs["cosine_constant_lr_ratio"] = (
|
|
||||||
self.cfg.cosine_constant_lr_ratio
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle custom optimizer
|
# Handle custom optimizer
|
||||||
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
|
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
|
||||||
@@ -444,11 +440,15 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
# set arg into trainer_args_kwargs with same name if value not None
|
# set arg into trainer_args_kwargs with same name if value not None
|
||||||
for arg in [
|
for arg in [
|
||||||
|
# optim/scheduler
|
||||||
"adam_beta1",
|
"adam_beta1",
|
||||||
"adam_beta2",
|
"adam_beta2",
|
||||||
"adam_beta3",
|
"adam_beta3",
|
||||||
"adam_epsilon",
|
"adam_epsilon",
|
||||||
"adam_epsilon2",
|
"adam_epsilon2",
|
||||||
|
"cosine_min_lr_ratio",
|
||||||
|
"cosine_constant_lr_ratio",
|
||||||
|
# trainer
|
||||||
"max_grad_norm",
|
"max_grad_norm",
|
||||||
"dataloader_num_workers",
|
"dataloader_num_workers",
|
||||||
"dataloader_pin_memory",
|
"dataloader_pin_memory",
|
||||||
|
|||||||
@@ -216,6 +216,7 @@ class TestHFRLTrainerBuilder:
|
|||||||
assert training_arguments.lr_scheduler_type == "cosine"
|
assert training_arguments.lr_scheduler_type == "cosine"
|
||||||
assert training_arguments.warmup_steps == 10
|
assert training_arguments.warmup_steps == 10
|
||||||
assert training_arguments.cosine_min_lr_ratio == 0.1
|
assert training_arguments.cosine_min_lr_ratio == 0.1
|
||||||
|
assert training_arguments.cosine_constant_lr_ratio == 0.2
|
||||||
|
|
||||||
# Other settings
|
# Other settings
|
||||||
assert training_arguments.dataloader_num_workers == 1
|
assert training_arguments.dataloader_num_workers == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user