use kwargs

This commit is contained in:
Wing Lian
2026-02-04 12:04:53 -05:00
parent 002b1ac967
commit b8d52a2193

View File

@@ -116,11 +116,11 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
         super().__init__(*args, **kwargs)
         self.generation_config = GenerationConfig(
-            max_new_tokens=args.kd_online_max_new_tokens,
+            max_new_tokens=kwargs.get("kd_online_max_new_tokens"),
             temperature=1.0,
             do_sample=True,
             top_k=0,
-            use_cache=False if args.gradient_checkpointing else True,
+            use_cache=False if kwargs.get("gradient_checkpointing") else True,
             pad_token_id=self.processing_class.pad_token_id,
         )
         # Set custom EOS tokens if they are specified by the model's generation