diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py index 7cc7f1ed2..402ff5acd 100644 --- a/src/axolotl/integrations/kd/args.py +++ b/src/axolotl/integrations/kd/args.py @@ -55,6 +55,7 @@ class KDArgs(BaseModel): ) kd_online_server_model: str | None = None kd_online_timeout: int | None = 120 + kd_online_max_new_tokens: int | None = 2048 kd_temperature_min: float | None = ( None # kd temperature scheduling during online kd ) @@ -75,3 +76,4 @@ class KDTrainingArgsMixin: kd_normalize_topk: float | None = ( None # whether to normalize student logits during KD ) + kd_online_max_new_tokens: int | None = None diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 0a86fbe0b..7677474db 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -116,8 +116,8 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer): super().__init__(*args, **kwargs) self.generation_config = GenerationConfig( - max_new_tokens=args.max_new_tokens, - temperature=args.temperature, + max_new_tokens=args.kd_online_max_new_tokens, + temperature=1.0, do_sample=True, top_k=0, use_cache=False if args.gradient_checkpointing else True, diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 30c3ba0fd..61f75c310 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -320,7 +320,7 @@ class PatchManager: else: has_remote_code = False - if has_remote_code and self.cfg.trust_remote_code is False: + if has_remote_code and self.cfg.trust_remote_code is not None: # If explicitly set in YAML, prefer that has_remote_code = self.cfg.trust_remote_code