Add configurable max new tokens for online KD generation

This commit is contained in:
Wing Lian
2026-02-04 11:55:19 -05:00
parent 17b01bfe36
commit 002b1ac967
3 changed files with 5 additions and 3 deletions

View File

@@ -55,6 +55,7 @@ class KDArgs(BaseModel):
) )
kd_online_server_model: str | None = None kd_online_server_model: str | None = None
kd_online_timeout: int | None = 120 kd_online_timeout: int | None = 120
kd_online_max_new_tokens: int | None = 2048
kd_temperature_min: float | None = ( kd_temperature_min: float | None = (
None # kd temperature scheduling during online kd None # kd temperature scheduling during online kd
) )
@@ -75,3 +76,4 @@ class KDTrainingArgsMixin:
kd_normalize_topk: float | None = ( kd_normalize_topk: float | None = (
None # whether to normalize student logits during KD None # whether to normalize student logits during KD
) )
kd_online_max_new_tokens: int | None = None

View File

@@ -116,8 +116,8 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.generation_config = GenerationConfig( self.generation_config = GenerationConfig(
max_new_tokens=args.max_new_tokens, max_new_tokens=args.kd_online_max_new_tokens,
temperature=args.temperature, temperature=1.0,
do_sample=True, do_sample=True,
top_k=0, top_k=0,
use_cache=False if args.gradient_checkpointing else True, use_cache=False if args.gradient_checkpointing else True,

View File

@@ -320,7 +320,7 @@ class PatchManager:
else: else:
has_remote_code = False has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False: if has_remote_code and self.cfg.trust_remote_code is not None:
# If explicitly set in YAML, prefer that # If explicitly set in YAML, prefer that
has_remote_code = self.cfg.trust_remote_code has_remote_code = self.cfg.trust_remote_code