From 53ec07d44c1d10626e18b84b6f6097dc8e966410 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 24 Dec 2024 19:54:32 -0500 Subject: [PATCH] use kd_alpha in the correct loss method --- src/axolotl/cli/utils.py | 2 +- src/axolotl/core/trainers/kd.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index f0e2573f7..ae7076215 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -55,7 +55,7 @@ def add_options_from_config(config_class: Type[BaseModel]): def decorator(function): # Process model fields in reverse order for correct option ordering for name, field in reversed(config_class.model_fields.items()): - if field.annotation == bool: + if field.annotation in (bool, Optional[bool]): field_name = name.replace("_", "-") option_name = f"--{field_name}/--no-{field_name}" function = click.option( diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py index 6047c72ff..e8adfab41 100644 --- a/src/axolotl/core/trainers/kd.py +++ b/src/axolotl/core/trainers/kd.py @@ -182,7 +182,8 @@ class AxolotlKDTrainer(AxolotlTrainer): ) if self.args.kd_ce_alpha > 0: - loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd + kd_alpha = self.args.kd_alpha + loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd else: loss = loss_kd # Save past state if it exists