diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 7677474db..416084a6a 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -116,11 +116,11 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer): super().__init__(*args, **kwargs) self.generation_config = GenerationConfig( - max_new_tokens=args.kd_online_max_new_tokens, + max_new_tokens=kwargs.get("kd_online_max_new_tokens"), temperature=1.0, do_sample=True, top_k=0, - use_cache=False if args.gradient_checkpointing else True, + use_cache=False if kwargs.get("gradient_checkpointing") else True, pad_token_id=self.processing_class.pad_token_id, ) # Set custom EOS tokens if they are specified by the model's generation