max new tokens for online generation
This commit is contained in:
@@ -55,6 +55,7 @@ class KDArgs(BaseModel):
|
||||
)
|
||||
kd_online_server_model: str | None = None
|
||||
kd_online_timeout: int | None = 120
|
||||
kd_online_max_new_tokens: int | None = 2048
|
||||
kd_temperature_min: float | None = (
|
||||
None # kd temperature scheduling during online kd
|
||||
)
|
||||
@@ -75,3 +76,4 @@ class KDTrainingArgsMixin:
|
||||
kd_normalize_topk: float | None = (
|
||||
None # whether to normalize student logits during KD
|
||||
)
|
||||
kd_online_max_new_tokens: int | None = None
|
||||
|
||||
@@ -116,8 +116,8 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.generation_config = GenerationConfig(
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
max_new_tokens=args.kd_online_max_new_tokens,
|
||||
temperature=1.0,
|
||||
do_sample=True,
|
||||
top_k=0,
|
||||
use_cache=False if args.gradient_checkpointing else True,
|
||||
|
||||
@@ -320,7 +320,7 @@ class PatchManager:
|
||||
else:
|
||||
has_remote_code = False
|
||||
|
||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||
if has_remote_code and self.cfg.trust_remote_code is not None:
|
||||
# If explicitly set in YAML, prefer that
|
||||
has_remote_code = self.cfg.trust_remote_code
|
||||
|
||||
|
||||
Reference in New Issue
Block a user