max new tokens for online generation

This commit is contained in:
Wing Lian
2026-02-04 11:55:19 -05:00
parent 17b01bfe36
commit 002b1ac967
3 changed files with 5 additions and 3 deletions

View File

@@ -55,6 +55,7 @@ class KDArgs(BaseModel):
)
kd_online_server_model: str | None = None
kd_online_timeout: int | None = 120
kd_online_max_new_tokens: int | None = 2048
kd_temperature_min: float | None = (
None # kd temperature scheduling during online kd
)
@@ -75,3 +76,4 @@ class KDTrainingArgsMixin:
kd_normalize_topk: float | None = (
None # whether to normalize student logits during KD
)
kd_online_max_new_tokens: int | None = None

View File

@@ -116,8 +116,8 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
super().__init__(*args, **kwargs)
self.generation_config = GenerationConfig(
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
max_new_tokens=args.kd_online_max_new_tokens,
temperature=1.0,
do_sample=True,
top_k=0,
use_cache=False if args.gradient_checkpointing else True,

View File

@@ -320,7 +320,7 @@ class PatchManager:
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False:
if has_remote_code and self.cfg.trust_remote_code is not None:
# If explicitly set in YAML, prefer that
has_remote_code = self.cfg.trust_remote_code