max new tokens for online generation
This commit is contained in:
@@ -55,6 +55,7 @@ class KDArgs(BaseModel):
|
|||||||
)
|
)
|
||||||
kd_online_server_model: str | None = None
|
kd_online_server_model: str | None = None
|
||||||
kd_online_timeout: int | None = 120
|
kd_online_timeout: int | None = 120
|
||||||
|
kd_online_max_new_tokens: int | None = 2048
|
||||||
kd_temperature_min: float | None = (
|
kd_temperature_min: float | None = (
|
||||||
None # kd temperature scheduling during online kd
|
None # kd temperature scheduling during online kd
|
||||||
)
|
)
|
||||||
@@ -75,3 +76,4 @@ class KDTrainingArgsMixin:
|
|||||||
kd_normalize_topk: float | None = (
|
kd_normalize_topk: float | None = (
|
||||||
None # whether to normalize student logits during KD
|
None # whether to normalize student logits during KD
|
||||||
)
|
)
|
||||||
|
kd_online_max_new_tokens: int | None = None
|
||||||
|
|||||||
@@ -116,8 +116,8 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.generation_config = GenerationConfig(
|
self.generation_config = GenerationConfig(
|
||||||
max_new_tokens=args.max_new_tokens,
|
max_new_tokens=args.kd_online_max_new_tokens,
|
||||||
temperature=args.temperature,
|
temperature=1.0,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
top_k=0,
|
top_k=0,
|
||||||
use_cache=False if args.gradient_checkpointing else True,
|
use_cache=False if args.gradient_checkpointing else True,
|
||||||
|
|||||||
@@ -320,7 +320,7 @@ class PatchManager:
|
|||||||
else:
|
else:
|
||||||
has_remote_code = False
|
has_remote_code = False
|
||||||
|
|
||||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
if has_remote_code and self.cfg.trust_remote_code is not None:
|
||||||
# If explicitly set in YAML, prefer that
|
# If explicitly set in YAML, prefer that
|
||||||
has_remote_code = self.cfg.trust_remote_code
|
has_remote_code = self.cfg.trust_remote_code
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user