use kd_alpha in the correct loss method

This commit is contained in:
Wing Lian
2024-12-24 19:54:32 -05:00
parent 8d77dc385e
commit 53ec07d44c
2 changed files with 3 additions and 2 deletions

View File

@@ -55,7 +55,7 @@ def add_options_from_config(config_class: Type[BaseModel]):
def decorator(function):
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
if field.annotation in (bool, Optional[bool]):
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(

View File

@@ -182,7 +182,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
)
if self.args.kd_ce_alpha > 0:
loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists