From 002b1ac96789e2ec63793d4bd1a6cc63c69b0277 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 4 Feb 2026 11:55:19 -0500 Subject: [PATCH] max new tokens for online generation --- src/axolotl/integrations/kd/args.py | 2 ++ src/axolotl/integrations/kd/trainer.py | 4 ++-- src/axolotl/loaders/patch_manager.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py index 7cc7f1ed2..402ff5acd 100644 --- a/src/axolotl/integrations/kd/args.py +++ b/src/axolotl/integrations/kd/args.py @@ -55,6 +55,7 @@ class KDArgs(BaseModel): ) kd_online_server_model: str | None = None kd_online_timeout: int | None = 120 + kd_online_max_new_tokens: int | None = 2048 kd_temperature_min: float | None = ( None # kd temperature scheduling during online kd ) @@ -75,3 +76,4 @@ class KDTrainingArgsMixin: kd_normalize_topk: float | None = ( None # whether to normalize student logits during KD ) + kd_online_max_new_tokens: int | None = None diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 0a86fbe0b..7677474db 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -116,8 +116,8 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer): super().__init__(*args, **kwargs) self.generation_config = GenerationConfig( - max_new_tokens=args.max_new_tokens, - temperature=args.temperature, + max_new_tokens=args.kd_online_max_new_tokens, + temperature=1.0, do_sample=True, top_k=0, use_cache=False if args.gradient_checkpointing else True, diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 30c3ba0fd..61f75c310 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -320,7 +320,7 @@ class PatchManager: else: has_remote_code = False - if has_remote_code and self.cfg.trust_remote_code is False: + if has_remote_code and self.cfg.trust_remote_code is not None: # If explicitly set in YAML, prefer that has_remote_code = self.cfg.trust_remote_code