use kwargs

This commit is contained in:
Wing Lian
2026-02-04 12:04:53 -05:00
parent 002b1ac967
commit b8d52a2193

View File

@@ -116,11 +116,11 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
super().__init__(*args, **kwargs)
self.generation_config = GenerationConfig(
-max_new_tokens=args.kd_online_max_new_tokens,
+max_new_tokens=kwargs.get("kd_online_max_new_tokens"),
temperature=1.0,
do_sample=True,
top_k=0,
-use_cache=False if args.gradient_checkpointing else True,
+use_cache=False if kwargs.get("gradient_checkpointing") else True,
pad_token_id=self.processing_class.pad_token_id,
)
# Set custom EOS tokens if they are specified by the model's generation