use kwargs
This commit is contained in:
@@ -116,11 +116,11 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.generation_config = GenerationConfig(
|
self.generation_config = GenerationConfig(
|
||||||
max_new_tokens=args.kd_online_max_new_tokens,
|
max_new_tokens=kwargs.get("kd_online_max_new_tokens"),
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
top_k=0,
|
top_k=0,
|
||||||
use_cache=False if args.gradient_checkpointing else True,
|
use_cache=False if kwargs.get("gradient_checkpointing") else True,
|
||||||
pad_token_id=self.processing_class.pad_token_id,
|
pad_token_id=self.processing_class.pad_token_id,
|
||||||
)
|
)
|
||||||
# Set custom EOS tokens if they are specified by the model's generation
|
# Set custom EOS tokens if they are specified by the model's generation
|
||||||
|
|||||||
Reference in New Issue
Block a user