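Use kwargs instead of args to read the online-KD generation settings (kd_online_max_new_tokens, gradient_checkpointing)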
This commit is contained in:
@@ -116,11 +116,11 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.generation_config = GenerationConfig(
|
||||
max_new_tokens=args.kd_online_max_new_tokens,
|
||||
max_new_tokens=kwargs.get("kd_online_max_new_tokens"),
|
||||
temperature=1.0,
|
||||
do_sample=True,
|
||||
top_k=0,
|
||||
use_cache=False if args.gradient_checkpointing else True,
|
||||
use_cache=False if kwargs.get("gradient_checkpointing") else True,
|
||||
pad_token_id=self.processing_class.pad_token_id,
|
||||
)
|
||||
# Set custom EOS tokens if they are specified by the model's generation
|
||||
|
||||
Reference in New Issue
Block a user