update to be deprecated evaluation_strategy (#1682) [skip ci]
* update to be deprecated evaluation_strategy and c4 dataset * chore: lint * remap eval strategy to new config and add tests
This commit is contained in:
@@ -1416,17 +1416,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
training_arguments_kwargs["eval_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.evaluation_strategy:
|
||||
training_arguments_kwargs[
|
||||
"evaluation_strategy"
|
||||
] = self.cfg.evaluation_strategy
|
||||
elif self.cfg.eval_strategy:
|
||||
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
training_arguments_kwargs["eval_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.save_steps:
|
||||
training_arguments_kwargs["save_strategy"] = "steps"
|
||||
@@ -1860,10 +1858,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.eval_dataset:
|
||||
training_args_kwargs["evaluation_strategy"] = "steps"
|
||||
training_args_kwargs["eval_strategy"] = "steps"
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
else:
|
||||
training_args_kwargs["evaluation_strategy"] = "no"
|
||||
training_args_kwargs["eval_strategy"] = "no"
|
||||
|
||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||
training_args_kwargs["bf16"] = True
|
||||
|
||||
@@ -64,10 +64,7 @@ class EvalFirstStepCallback(
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if (
|
||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and state.global_step == 1
|
||||
):
|
||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||
control.should_evaluate = True
|
||||
return control
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""Module for working with config dicts"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -10,7 +8,6 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.integrations.config import merge_input_args
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
|
||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
)
|
||||
@@ -247,370 +244,3 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||
return DictDefault(
|
||||
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
||||
)
|
||||
|
||||
|
||||
def legacy_validate_config(cfg):
|
||||
"""
|
||||
This is a "pre-validation" step that handles the yaml configuration before we have any
|
||||
information about the model architecture
|
||||
"""
|
||||
if is_torch_bf16_gpu_available():
|
||||
if not cfg.bf16 and not cfg.bfloat16:
|
||||
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
||||
else:
|
||||
if (
|
||||
not cfg.merge_lora
|
||||
and not cfg.is_preprocess
|
||||
and (cfg.bf16 is True or cfg.bfloat16 is True)
|
||||
):
|
||||
raise ValueError(
|
||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||
)
|
||||
if (
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
not (cfg.bf16 or cfg.bfloat16)
|
||||
and (cfg.fp16 or cfg.float16)
|
||||
and not cfg.adapter
|
||||
and not cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
LOG.warning(
|
||||
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
|
||||
)
|
||||
# ValueError: Attempting to unscale FP16 gradients.
|
||||
# OR
|
||||
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
|
||||
if cfg.max_packed_sequence_len:
|
||||
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||
|
||||
if cfg.sample_packing and cfg.rl:
|
||||
raise ValueError("`sample_packing: true` does not work with RLHF training")
|
||||
|
||||
if cfg.sample_packing and not cfg.pad_to_sequence_len:
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
)
|
||||
|
||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||
raise ValueError(
|
||||
"please set only one of gradient_accumulation_steps or batch_size"
|
||||
)
|
||||
if cfg.batch_size:
|
||||
LOG.warning(
|
||||
"%s\n%s",
|
||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||
)
|
||||
if (
|
||||
cfg.eval_batch_size
|
||||
and cfg.micro_batch_size
|
||||
and cfg.eval_batch_size != cfg.micro_batch_size
|
||||
):
|
||||
LOG.warning(
|
||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||
)
|
||||
|
||||
if cfg.adapter == "qlora":
|
||||
if cfg.merge_lora:
|
||||
# can't merge qlora if loaded in 8bit or 4bit
|
||||
if cfg.load_in_8bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||
|
||||
if cfg.gptq:
|
||||
raise ValueError("Can't merge qlora if gptq")
|
||||
|
||||
if cfg.load_in_4bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||
|
||||
else:
|
||||
if cfg.load_in_8bit:
|
||||
raise ValueError("Can't load qlora in 8bit")
|
||||
|
||||
if cfg.gptq:
|
||||
raise ValueError("Can't load qlora if gptq")
|
||||
|
||||
if not cfg.load_in_4bit:
|
||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with QLoRA")
|
||||
|
||||
loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||
raise ValueError("Fused modules are not supported with LoRA")
|
||||
|
||||
if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
|
||||
raise ValueError(
|
||||
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
|
||||
)
|
||||
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
|
||||
if cfg.fsdp:
|
||||
raise ValueError("fsdp not supported with ReLoRA")
|
||||
|
||||
if cfg.deepspeed:
|
||||
raise ValueError("deepspeed not supported with ReLoRA")
|
||||
|
||||
if cfg.lr_scheduler == "one_cycle":
|
||||
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||
|
||||
if cfg.trust_remote_code:
|
||||
LOG.warning(
|
||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||
)
|
||||
|
||||
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
|
||||
raise ValueError(
|
||||
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
||||
)
|
||||
|
||||
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||
raise ValueError("FSDP is not supported for falcon models")
|
||||
|
||||
if (
|
||||
cfg.base_model and "mpt" in cfg.base_model.lower()
|
||||
) and cfg.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||
|
||||
if cfg.flash_optimum is True:
|
||||
if cfg.adapter:
|
||||
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
||||
if cfg.fp16 or cfg.bf16:
|
||||
raise ValueError("AMP is not supported with BetterTransformer")
|
||||
if cfg.float16 is not True and cfg.bfloat16 is not True:
|
||||
LOG.warning(
|
||||
"You should probably set bfloat16 or float16 to true to "
|
||||
"load the model in float16 for BetterTransformers"
|
||||
)
|
||||
if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
|
||||
LOG.warning("torch>=2.0.0 required")
|
||||
raise ValueError(
|
||||
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
||||
)
|
||||
|
||||
if cfg.pretraining_dataset and cfg.group_by_length:
|
||||
LOG.warning(
|
||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||
)
|
||||
if cfg.pretraining_dataset and not cfg.max_steps:
|
||||
raise ValueError(
|
||||
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
|
||||
)
|
||||
|
||||
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
||||
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||
):
|
||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
|
||||
if cfg.push_to_hub_model_id:
|
||||
raise ValueError(
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
|
||||
LOG.warning(
|
||||
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.revision_of_model:
|
||||
raise ValueError(
|
||||
"revision_of_model is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove revision_of_model from the config."
|
||||
)
|
||||
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
# raise ValueError(
|
||||
# "sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
# )
|
||||
|
||||
if cfg.sample_packing and cfg.xformers_attention:
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
|
||||
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
||||
LOG.warning(
|
||||
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
||||
"This may work on H100s."
|
||||
)
|
||||
|
||||
if cfg.early_stopping_patience:
|
||||
if not cfg.save_steps or not cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
||||
)
|
||||
if cfg.save_steps % cfg.eval_steps != 0:
|
||||
raise ValueError(
|
||||
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
||||
)
|
||||
|
||||
if cfg.saves_per_epoch and cfg.save_steps:
|
||||
raise ValueError(
|
||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
)
|
||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if (
|
||||
cfg.evals_per_epoch
|
||||
and cfg.evaluation_strategy
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
if (
|
||||
cfg.evaluation_strategy
|
||||
and cfg.eval_steps
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.val_set_size == 0
|
||||
and (cfg.eval_steps or cfg.evaluation_strategy)
|
||||
and not cfg.test_datasets
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.sample_packing
|
||||
and cfg.eval_table_size
|
||||
and cfg.eval_sample_packing is not False
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
||||
)
|
||||
|
||||
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
|
||||
raise ValueError(
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
)
|
||||
|
||||
if cfg.rope_scaling:
|
||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||
|
||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||
cfg.wandb_name = cfg.wandb_run_id
|
||||
|
||||
LOG.warning(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
)
|
||||
|
||||
if cfg.noisy_embedding_alpha is not None:
|
||||
# Deprecated, use neftune_noise_alpha
|
||||
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||
if cfg.neftune_noise_alpha is None:
|
||||
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
||||
else:
|
||||
# User is providing both; bail and have them sort out their settings
|
||||
raise ValueError(
|
||||
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
||||
)
|
||||
|
||||
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||
|
||||
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
|
||||
raise ValueError(
|
||||
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.unfrozen_parameters
|
||||
and cfg.gradient_checkpointing_kwargs
|
||||
and cfg.gradient_checkpointing_kwargs.use_reentrant is True
|
||||
):
|
||||
# https://github.com/huggingface/transformers/issues/21381
|
||||
raise ValueError(
|
||||
"`use_reentrant` must be false when used with partially frozen model."
|
||||
)
|
||||
|
||||
if cfg.deepspeed and Path(cfg.deepspeed).is_file():
|
||||
with open(cfg.deepspeed, encoding="utf-8") as file:
|
||||
contents = file.read()
|
||||
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
|
||||
if cfg.flash_attention:
|
||||
if (
|
||||
deepspeed_cfg.zero_optimization
|
||||
and deepspeed_cfg.zero_optimization.stage == 3
|
||||
):
|
||||
if not (
|
||||
(
|
||||
deepspeed_cfg.bf16
|
||||
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
|
||||
is True
|
||||
)
|
||||
or (
|
||||
deepspeed_cfg.fp16
|
||||
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
|
||||
is True
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
||||
)
|
||||
if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
|
||||
LOG.warning(
|
||||
f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
|
||||
)
|
||||
|
||||
if cfg.test_datasets and cfg.val_set_size:
|
||||
raise ValueError(
|
||||
"non-zero val_set_size should not be used with test_datasets configuration"
|
||||
)
|
||||
|
||||
if cfg.fsdp and "bnb" in cfg.optimizer:
|
||||
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
|
||||
|
||||
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
|
||||
raise ValueError(
|
||||
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
|
||||
)
|
||||
|
||||
if cfg.eval_causal_lm_metrics:
|
||||
if not isinstance(cfg.eval_causal_lm_metrics, list):
|
||||
raise ValueError("eval_causal_lm_metrics must be a list")
|
||||
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
||||
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
|
||||
raise ValueError(
|
||||
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
|
||||
)
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
# no 8bit adaAmw w bf16
|
||||
|
||||
# GPT-NeoX
|
||||
# evals broken when extending context len
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
|
||||
# attention_mask = causal_mask + attention_mask
|
||||
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
|
||||
|
||||
@@ -68,6 +68,7 @@ class DeprecatedParameters(BaseModel):
|
||||
rope_scaling: Optional[Any] = None
|
||||
noisy_embedding_alpha: Optional[float] = None
|
||||
dpo_beta: Optional[float] = None
|
||||
evaluation_strategy: Optional[str] = None
|
||||
|
||||
@field_validator("max_packed_sequence_len")
|
||||
@classmethod
|
||||
@@ -99,6 +100,13 @@ class DeprecatedParameters(BaseModel):
|
||||
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
||||
return dpo_beta
|
||||
|
||||
@field_validator("evaluation_strategy")
|
||||
@classmethod
|
||||
def validate_evaluation_strategy(cls, evaluation_strategy):
|
||||
if evaluation_strategy is not None:
|
||||
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
||||
return evaluation_strategy
|
||||
|
||||
|
||||
class RemappedParameters(BaseModel):
|
||||
"""parameters that have been remapped to other names"""
|
||||
@@ -731,7 +739,7 @@ class AxolotlInputConfig(
|
||||
warmup_ratio: Optional[float] = None
|
||||
eval_steps: Optional[Union[int, float]] = None
|
||||
evals_per_epoch: Optional[Union[int]] = None
|
||||
evaluation_strategy: Optional[str] = None
|
||||
eval_strategy: Optional[str] = None
|
||||
save_steps: Optional[Union[int, float]] = None
|
||||
saves_per_epoch: Optional[int] = None
|
||||
save_strategy: Optional[str] = None
|
||||
@@ -1033,21 +1041,21 @@ class AxolotlInputConfig(
|
||||
@classmethod
|
||||
def check_evals(cls, data):
|
||||
if (
|
||||
data.get("evaluation_strategy")
|
||||
data.get("eval_strategy")
|
||||
and data.get("eval_steps")
|
||||
and data.get("evaluation_strategy") != "steps"
|
||||
and data.get("eval_strategy") != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
data.get("val_set_size") == 0
|
||||
and (data.get("eval_steps") or data.get("evaluation_strategy"))
|
||||
and (data.get("eval_steps") or data.get("eval_strategy"))
|
||||
and not data.get("test_datasets")
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
if data.get("evals_per_epoch") and data.get("eval_steps"):
|
||||
raise ValueError(
|
||||
@@ -1055,11 +1063,11 @@ class AxolotlInputConfig(
|
||||
)
|
||||
if (
|
||||
data.get("evals_per_epoch")
|
||||
and data.get("evaluation_strategy")
|
||||
and data.get("evaluation_strategy") != "steps"
|
||||
and data.get("eval_strategy")
|
||||
and data.get("eval_strategy") != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
"eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
|
||||
if data.get("do_bench_eval") and not (
|
||||
@@ -1319,6 +1327,19 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_eval_strategy(cls, data):
|
||||
if (
|
||||
data.get("evaluation_strategy") is not None
|
||||
and data.get("eval_strategy") is None
|
||||
):
|
||||
LOG.info(
|
||||
"explicitly setting `eval_strategy` from the `evaluation_strategy`"
|
||||
)
|
||||
data["eval_strategy"] = data.get("evaluation_strategy")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user