Decouple generate_during_eval from wandb to support other visualizers (#2849) [skip ci]
* Add generate_during_eval for mlflow for dpo
* Decouple generate_during_eval from wandb
This commit is contained in:
@@ -28,7 +28,7 @@ class DPOStrategy:
|
|||||||
training_args_kwargs["max_completion_length"] = None
|
training_args_kwargs["max_completion_length"] = None
|
||||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||||
if cfg.dpo_use_weighting is not None:
|
if cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
if cfg.dpo_padding_free is not None:
|
if cfg.dpo_padding_free is not None:
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ class AxolotlInputConfig(
|
|||||||
dpo_label_smoothing: float | None = None
|
dpo_label_smoothing: float | None = None
|
||||||
dpo_norm_loss: bool | None = None
|
dpo_norm_loss: bool | None = None
|
||||||
dpo_padding_free: bool | None = None
|
dpo_padding_free: bool | None = None
|
||||||
|
dpo_generate_during_eval: bool | None = None
|
||||||
|
|
||||||
datasets: (
|
datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
|
|||||||
Reference in New Issue
Block a user