use kd_alpha in the correct loss method

This commit is contained in:
Wing Lian
2024-12-24 19:54:32 -05:00
parent 8d77dc385e
commit 53ec07d44c
2 changed files with 3 additions and 2 deletions

View File

@@ -55,7 +55,7 @@ def add_options_from_config(config_class: Type[BaseModel]):
def decorator(function): def decorator(function):
# Process model fields in reverse order for correct option ordering # Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()): for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool: if field.annotation in (bool, Optional[bool]):
field_name = name.replace("_", "-") field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}" option_name = f"--{field_name}/--no-{field_name}"
function = click.option( function = click.option(

View File

@@ -182,7 +182,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
) )
if self.args.kd_ce_alpha > 0: if self.args.kd_ce_alpha > 0:
loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else: else:
loss = loss_kd loss = loss_kd
# Save past state if it exists # Save past state if it exists