diff --git a/.mypy.ini b/.mypy.ini index bb9a21c65..ede9fef88 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,5 +1,5 @@ [mypy] - +plugins = pydantic.mypy exclude = venv [mypy-alpaca_lora_4bit.*] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c811a6eb3..6c5f20589 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,6 +31,7 @@ repos: additional_dependencies: [ 'types-PyYAML', + 'pydantic>=2.5.3', ] - repo: https://github.com/PyCQA/bandit rev: 1.7.5 diff --git a/README.md b/README.md index bf10d1ec8..1179529a0 100644 --- a/README.md +++ b/README.md @@ -543,7 +543,7 @@ is_mistral_derived_model: is_qwen_derived_model: # optional overrides to the base model configuration -model_config: +model_config_overrides: # RoPE Scaling https://github.com/huggingface/transformers/pull/24653 rope_scaling: type: # linear | dynamic @@ -560,8 +560,6 @@ bnb_config_kwargs: # Whether you are training a 4-bit GPTQ quantized model gptq: true -gptq_groupsize: 128 # group size -gptq_model_v1: false # v1 or v2 # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer load_in_8bit: true @@ -819,10 +817,6 @@ cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosin # For one_cycle optim lr_div_factor: # Learning rate div factor -# For log_sweep optim -log_sweep_min_lr: -log_sweep_max_lr: - # Specify optimizer # Valid values are driven by the Transformers OptimizerNames class, see: # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134 diff --git a/requirements.txt b/requirements.txt index 6532d3999..8dce6daa7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ tokenizers==0.15.0 bitsandbytes>=0.41.1 accelerate==0.26.1 deepspeed>=0.13.1 +pydantic>=2.5.3 addict fire PyYAML>=6.0 @@ -27,7 +28,7 @@ scipy scikit-learn==1.2.2 pynvml art -fschat==0.2.34 +fschat==0.2.36 gradio==3.50.2 tensorboard diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 6b3894cb5..a15634247 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -24,11 +24,13 @@ from art import text2art from huggingface_hub import HfApi from huggingface_hub.utils import LocalTokenNotFoundError from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer +from transformers.utils import is_torch_bf16_gpu_available from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta from axolotl.utils.config import ( + GPUCapabilities, normalize_cfg_datasets, normalize_config, validate_config, @@ -328,7 +330,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): # load the config from the yaml file with open(config, encoding="utf-8") as file: cfg: DictDefault = DictDefault(yaml.safe_load(file)) - cfg.axolotl_config_path = config # if there are any options passed in the cli, if it is something that seems valid from the yaml, # then overwrite the value cfg_keys = cfg.keys() @@ -341,7 +342,21 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): else: cfg[k] = kwargs[k] - validate_config(cfg) + cfg.axolotl_config_path = config + + try: + device_props = torch.cuda.get_device_properties("cuda") + gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) + except: # pylint: disable=bare-except # noqa: E722 + gpu_version = None + + capabilities = GPUCapabilities( + 
bf16=is_torch_bf16_gpu_available(), + n_gpu=os.environ.get("WORLD_SIZE", 1), + compute_capability=gpu_version, + ) + + cfg = validate_config(cfg, capabilities=capabilities) prepare_optim_env(cfg) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config/__init__.py similarity index 97% rename from src/axolotl/utils/config.py rename to src/axolotl/utils/config/__init__.py index 1fc470da9..b21db3176 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config/__init__.py @@ -3,11 +3,17 @@ import json import logging import os from pathlib import Path +from typing import Optional import torch from transformers.utils import is_torch_bf16_gpu_available from axolotl.utils.bench import log_gpu_memory_usage +from axolotl.utils.config.models.input.v0_4_1 import ( + AxolotlConfigWCapabilities, + AxolotlInputConfig, +) +from axolotl.utils.config.models.internals import GPUCapabilities from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model_config @@ -191,7 +197,15 @@ def normalize_cfg_datasets(cfg): cfg.datasets[idx].conversation = "chatml" -def validate_config(cfg): +def validate_config(cfg: DictDefault, capabilities: Optional[GPUCapabilities] = None): + if capabilities: + return DictDefault( + dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities)) + ) + return DictDefault(dict(AxolotlInputConfig(**cfg.to_dict()))) + + +def legacy_validate_config(cfg): """ This is a "pre-validation" step that handles the yaml configuration before we have any information about the model architecture @@ -480,9 +494,6 @@ def validate_config(cfg): if cfg.rope_scaling: LOG.warning("`rope_scaling` should now be be a key under `model_config`") - if cfg.warmup_steps and cfg.warmup_ratio: - raise ValueError("warmup_steps and warmup_ratio are mutually exclusive") - if cfg.wandb_run_id and not cfg.wandb_name: cfg.wandb_name = cfg.wandb_run_id diff --git a/src/axolotl/utils/config/models/__init__.py b/src/axolotl/utils/config/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/utils/config/models/input/__init__.py b/src/axolotl/utils/config/models/input/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/utils/config/models/input/next/__init__.py b/src/axolotl/utils/config/models/input/next/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py new file mode 100644 index 000000000..433c84af1 --- /dev/null +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -0,0 +1,931 @@ +""" +Module for pydantic models for configuration +""" + +import logging +import os +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, Field, conlist, field_validator, model_validator +from transformers import SchedulerType +from transformers.training_args import OptimizerNames + +from axolotl.utils.config.models.internals import GPUCapabilities + +LOG = logging.getLogger("axolotl.utils.config.models.input") + + +class DeprecatedParameters(BaseModel): + """configurations that are deprecated""" + + max_packed_sequence_len: Optional[int] = None + rope_scaling: Optional[Any] = None + noisy_embedding_alpha: Optional[float] = None + + @field_validator("max_packed_sequence_len") + @classmethod + def validate_max_packed_sequence_len(cls, max_packed_sequence_len): + if max_packed_sequence_len: + raise 
DeprecationWarning("`max_packed_sequence_len` is no longer supported") + return max_packed_sequence_len + + @field_validator("rope_scaling") + @classmethod + def validate_rope_scaling(cls, rope_scaling): + if rope_scaling: + raise DeprecationWarning( + "`rope_scaling` is no longer supported, it should now be be a key under `model_config`" + ) + return rope_scaling + + @field_validator("noisy_embedding_alpha") + @classmethod + def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha): + if noisy_embedding_alpha: + LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha") + return noisy_embedding_alpha + + +class PretrainingDataset(BaseModel): + """pretraining dataset configuration subset""" + + path: Optional[str] = None + + +class UserDefinedPrompterType(BaseModel): + """structure for user defined prompt types""" + + system_prompt: Optional[str] = None + system_format: Optional[str] = None + field_system: Optional[str] = None + field_instruction: Optional[str] = None + field_input: Optional[str] = None + field_output: Optional[str] = None + + format: Optional[str] = None + no_input_format: Optional[str] = None + field: Optional[str] = None + + +class SFTDataset(BaseModel): + """SFT configuration subset""" + + path: Optional[str] = None + split: Optional[str] = None + type: Optional[Union[str, UserDefinedPrompterType]] = None + shards: Optional[int] = None + conversation: Optional[str] = None + data_files: Optional[List[str]] = None + name: Optional[str] = None + ds_type: Optional[str] = None + train_on_split: Optional[str] = None + + field_human: Optional[str] = None + field_model: Optional[str] = None + + +class DPODataset(BaseModel): + """DPO configuration subset""" + + path: Optional[str] = None + split: Optional[str] = None + type: Optional[str] = None + data_files: Optional[List[str]] = None + + +class RLType(str, Enum): + """RL trainer type configuration subset""" + + dpo = "dpo" # pylint: disable=invalid-name + ipo = "ipo" # pylint: disable=invalid-name + kto_pair = "kto_pair" # pylint: disable=invalid-name + + +class ChatTemplate(str, Enum): + """Chat templates configuration subset""" + + chatml = "chatml" # pylint: disable=invalid-name + inst = "inst" # pylint: disable=invalid-name + + +class LoftQConfig(BaseModel): + """LoftQ configuration subset""" + + loftq_bits: int = Field(default=4, metadata={"help": "Quantization bits for LoftQ"}) + # loftq_iter: int = Field(default=1, metadata={"help": "Alternating iterations for LoftQ"}) + + +class PeftConfig(BaseModel): + """peftq configuration subset""" + + loftq_config: Optional[LoftQConfig] = None + + +class AutoType(str, Enum): + """auto type string configuration subset - used for bf16""" + + AUTO = "auto" + + +class SpecialTokensConfig(BaseModel): + """Special tokens configuration subset""" + + bos_token: Optional[str] = None + eos_token: Optional[str] = None + pad_token: Optional[str] = None + unk_token: Optional[str] = None + additional_special_tokens: Optional[List[str]] = None + + +class LoraConfig(BaseModel): + """Peft / LoRA configuration subset""" + + load_in_8bit: Optional[bool] = Field(default=False) + load_in_4bit: Optional[bool] = Field(default=False) + + adapter: Optional[str] = None + lora_model_dir: Optional[str] = None + lora_rank: Optional[int] = None + lora_alpha: Optional[int] = None + lora_fan_in_fan_out: Optional[bool] = None + lora_target_modules: Optional[List[str]] = None + lora_target_linear: Optional[bool] = None + lora_modules_to_save: Optional[List[str]] = None + lora_dropout: 
Optional[float] = None + peft_layers_to_transform: Optional[List[int]] = None + peft: Optional[PeftConfig] = None + + lora_on_cpu: Optional[bool] = None + gptq: Optional[bool] = None + bnb_config_kwargs: Optional[Dict[str, Any]] = None + + merge_lora: Optional[bool] = None + + @model_validator(mode="before") + @classmethod + def validate_adapter(cls, data): + if not data.get("adapter") and ( + data.get("load_in_8bit") or data.get("load_in_4bit") + ): + raise ValueError( + "load_in_8bit and load_in_4bit are not supported without setting an adapter." + "If you want to full finetune, please turn off load_in_8bit and load_in_4bit." + ) + return data + + @model_validator(mode="after") + def validate_qlora(self): + if self.adapter == "qlora": + if self.merge_lora: + # can't merge qlora if loaded in 8bit or 4bit + if self.load_in_8bit: + raise ValueError("Can't merge qlora if loaded in 8bit") + + if self.gptq: + raise ValueError("Can't merge qlora if gptq") + + if self.load_in_4bit: + raise ValueError("Can't merge qlora if loaded in 4bit") + + else: + if self.load_in_8bit: + raise ValueError("Can't load qlora in 8bit") + + if self.gptq: + raise ValueError("Can't load qlora if gptq") + + if not self.load_in_4bit: + raise ValueError("Require cfg.load_in_4bit to be True for qlora") + return self + + +class ReLoRAConfig(BaseModel): + """ReLoRA configuration subset""" + + relora_steps: Optional[int] = None + relora_warmup_steps: Optional[int] = None + relora_anneal_steps: Optional[int] = None + relora_prune_ratio: Optional[float] = None + relora_cpu_offload: Optional[bool] = None + + +class ModelInputConfig(BaseModel): + """model to train on configuration subset""" + + base_model: str + base_model_config: Optional[str] = None + tokenizer_config: Optional[str] = None + tokenizer_use_fast: Optional[bool] = None + tokenizer_legacy: Optional[bool] = None + tokenizer_type: Optional[str] = Field( + default=None, metadata={"help": "transformers tokenizer class"} + ) + model_type: Optional[str] = Field(default=None) + model_revision: Optional[str] = None + trust_remote_code: Optional[bool] = None + + model_config_overrides: Optional[Dict[str, Any]] = None + + @field_validator("trust_remote_code") + @classmethod + def hint_trust_remote_code(cls, trust_remote_code): + if trust_remote_code: + LOG.warning( + "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." 
+ ) + return trust_remote_code + + +class HyperparametersConfig(BaseModel): + """training hyperparams configuration subset""" + + gradient_accumulation_steps: Optional[int] = Field(default=1) + micro_batch_size: Optional[int] = Field( + default=1, + metadata={"help": "per gpu micro batch size for training"}, + ) + batch_size: Optional[int] = Field( + default=None, + metadata={ + "help": "Total batch size, we do not recommended setting this manually" + }, + ) + eval_batch_size: Optional[int] = Field( + default=None, + metadata={ + "help": "per gpu micro batch size for evals, defaults to value of micro_batch_size" + }, + ) + + train_on_inputs: Optional[bool] = None + group_by_length: Optional[bool] = None + + learning_rate: Union[str, float] + weight_decay: Optional[float] = None + optimizer: Optional[OptimizerNames] = None + torchdistx_path: Optional[str] = None + lr_scheduler: Optional[SchedulerType] = None + lr_scheduler_kwargs: Optional[Dict[str, Any]] = None + lr_quadratic_warmup: Optional[bool] = None + cosine_min_lr_ratio: Optional[float] = None + cosine_constant_lr_ratio: Optional[float] = None + lr_div_factor: Optional[float] = None + + adam_epsilon: Optional[float] = None + adam_beta1: Optional[float] = None + adam_beta2: Optional[float] = None + max_grad_norm: Optional[float] = None + num_epochs: int = Field(default=1) + + @field_validator("batch_size") + @classmethod + def hint_batch_size_set(cls, batch_size): + if batch_size: + LOG.warning( + "%s\n%s", + "batch_size is not recommended. Please use gradient_accumulation_steps instead.", + "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.", + ) + return batch_size + + +class ModelOutputConfig(BaseModel): + """model save configuration subset""" + + output_dir: str = Field(default="./model-out") + hub_model_id: Optional[str] = None + hub_strategy: Optional[str] = None + save_safetensors: Optional[bool] = None + + +class MLFlowConfig(BaseModel): + """mlflow configuration subset""" + + use_mlflow: Optional[str] = None + mlflow_tracking_uri: Optional[str] = None + mlflow_experiment_name: Optional[str] = None + + +class WandbConfig(BaseModel): + """wandb configuration subset""" + + use_wandb: Optional[bool] = None + wandb_name: Optional[str] = None + wandb_run_id: Optional[str] = None + wandb_mode: Optional[str] = None + wandb_project: Optional[str] = None + wandb_entity: Optional[str] = None + wandb_watch: Optional[str] = None + wandb_log_model: Optional[str] = None + + @model_validator(mode="before") + @classmethod + def check_wandb_run(cls, data): + if data.get("wandb_run_id") and not data.get("wandb_name"): + data["wandb_name"] = data.get("wandb_run_id") + + LOG.warning( + "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead." 
+ ) + + return data + + +# pylint: disable=too-many-public-methods,too-many-ancestors +class AxolotlInputConfig( + ModelInputConfig, + LoraConfig, + ReLoRAConfig, + HyperparametersConfig, + WandbConfig, + MLFlowConfig, + DeprecatedParameters, + BaseModel, +): + """wrapper of all config options""" + + strict: Optional[bool] = Field(default=False) + resume_from_checkpoint: Optional[str] = None + auto_resume_from_checkpoints: Optional[bool] = None + resize_token_embeddings_to_32x: Optional[bool] = None + + rl: Optional[RLType] = None + + datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore + dataset_prepared_path: Optional[str] = None + dataset_shard_num: Optional[int] = None + dataset_shard_idx: Optional[int] = None + + pretraining_dataset: Optional[ # type: ignore + conlist(Union[SFTDataset, PretrainingDataset], min_length=1) + ] = Field( + default=None, metadata={"help": {"streaming dataset to use for pretraining"}} + ) + dataset_processes: Optional[int] = Field(default=os.cpu_count()) + dataset_keep_in_memory: Optional[bool] = None + dataloader_pin_memory: Optional[bool] = None + dataloader_num_workers: Optional[int] = None + dataloader_prefetch_factor: Optional[int] = None + dataloader_drop_last: Optional[bool] = None + + push_dataset_to_hub: Optional[str] = None + hf_use_auth_token: Optional[bool] = None + + device: Optional[Any] = None + device_map: Optional[Any] = None + world_size: Optional[int] = None + local_rank: Optional[int] = None + ddp: Optional[bool] = None + + seed: Optional[int] = None + ddp_timeout: Optional[int] = None + ddp_bucket_cap_mb: Optional[int] = None + ddp_broadcast_buffers: Optional[bool] = None + ddp_find_unused_parameters: Optional[bool] = None + + eval_table_size: Optional[int] = None + eval_max_new_tokens: Optional[int] = None + do_causal_lm_eval: Optional[bool] = None + eval_causal_lm_metrics: Optional[List[str]] = None + do_bench_eval: Optional[bool] = None + bench_dataset: Optional[str] = None + metric_for_best_model: Optional[str] = None + greater_is_better: Optional[bool] = None + + loss_watchdog_threshold: Optional[float] = None + loss_watchdog_patience: Optional[int] = None + + bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO + fp16: Optional[bool] = None + bfloat16: Optional[bool] = None # for non-AMP cases + float16: Optional[bool] = None # for non-AMP cases + tf32: Optional[bool] = None + float32: Optional[bool] = None + + # torch_dtype: Optional[torch.dtype] + + gradient_checkpointing: Optional[bool] = Field(default=False) + gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None + + unfrozen_parameters: Optional[List[str]] = None + + sequence_len: int = Field(default=1024) + sample_packing: Optional[bool] = None + eval_sample_packing: Optional[bool] = None + pad_to_sequence_len: Optional[bool] = None + + xformers_attention: Optional[bool] = None + sdp_attention: Optional[bool] = None + s2_attention: Optional[bool] = None + flash_attention: Optional[bool] = None + flash_attn_cross_entropy: Optional[bool] = None + flash_attn_rms_norm: Optional[bool] = None + flash_attn_fuse_qkv: Optional[bool] = None + flash_attn_fuse_mlp: Optional[bool] = None + flash_optimum: Optional[bool] = None + + deepspeed: Optional[Union[str, Dict[str, Any]]] = None + fsdp: Optional[List[str]] = None + fsdp_config: Optional[Dict[str, Any]] = None + + val_set_size: Optional[float] = Field(default=0.0) + + special_tokens: Optional[SpecialTokensConfig] = None + tokens: Optional[List[str]] = None + + torch_compile: 
Optional[bool] = None + torch_compile_backend: Optional[str] = None + + max_steps: Optional[int] = None + warmup_steps: Optional[int] = None + warmup_ratio: Optional[float] = None + eval_steps: Optional[int] = None + evaluation_strategy: Optional[str] = None + save_steps: Optional[int] = None + saves_per_epoch: Optional[int] = None + save_strategy: Optional[str] = None + save_total_limit: Optional[int] = None + logging_steps: Optional[int] = None + early_stopping_patience: Optional[int] = None + + neftune_noise_alpha: Optional[float] = None + + max_memory: Optional[Union[int, str]] = None + gpu_memory_limit: Optional[Union[int, str]] = None + + chat_template: Optional[Union[Literal["chatml", "inst"], ChatTemplate]] = None + default_system_message: Optional[str] = None + + # INTERNALS - document for now, generally not set externally + is_preprocess: Optional[bool] = None + + total_num_tokens: Optional[int] = None + total_supervised_tokens: Optional[int] = None + sample_packing_eff_est: Optional[float] = None + axolotl_config_path: Optional[str] = None + + is_falcon_derived_model: Optional[bool] = Field(default=False) + is_llama_derived_model: Optional[bool] = Field(default=False) + is_mistral_derived_model: Optional[bool] = Field(default=False) + is_qwen_derived_model: Optional[bool] = Field(default=False) + + @field_validator("datasets", mode="before") + @classmethod + def fix_sharegpt_datasets(cls, datasets): + for idx, ds_cfg in enumerate(datasets): + if not ds_cfg["type"]: + continue + if ds_cfg["type"] == "sharegpt:chat": + LOG.warning( + PendingDeprecationWarning( + "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead." + ) + ) + datasets[idx]["type"] = "sharegpt" + if "sharegpt_simple" in ds_cfg["type"]: + LOG.warning( + PendingDeprecationWarning( + "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead." + ) + ) + datasets[idx]["type"] = datasets[idx]["type"].replace( + "sharegpt_simple", "sharegpt" + ) + return datasets + + @model_validator(mode="before") + @classmethod + def check_batch_size_fields(cls, data): + fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size") + non_empty_count = sum(1 for field in fields if data.get(field)) + + if non_empty_count < 2: + raise ValueError(f"At least two of {', '.join(fields)} must be set") + return data + + @model_validator(mode="before") + @classmethod + def check_pretraining_w_max_steps(cls, data): + if data.get("pretraining_dataset") and not data.get("max_steps"): + raise ValueError( + "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_pretraining_w_group_by_length(cls, data): + if data.get("pretraining_dataset") and data.get("group_by_length"): + LOG.warning( + "You probably want to disable group_by_length as it will force a streamed dataset to download completely." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_gptq_w_revision(cls, data): + if data.get("gptq") and data.get("model_revision"): + raise ValueError( + "model_revision is not supported for GPTQ models. " + + "Please download the model from HuggingFace Hub manually for correct branch, " + + "point to its path, and remove model_revision from the config." 
+ ) + return data + + @model_validator(mode="before") + @classmethod + def check_sample_packing_w_xformers(cls, data): + if data.get("sample_packing") and data.get("xformers_attention"): + raise ValueError( + "sample_packing not compatible with xformers_attention. Use flash_attention" + ) + + return data + + @model_validator(mode="before") + @classmethod + def check_sample_packing_w_rl(cls, data): + if data.get("sample_packing") and data.get("rl"): + raise ValueError("`sample_packing: true` does not work with RLHF training") + return data + + @model_validator(mode="before") + @classmethod + def hint_sample_packing_padding(cls, data): + if data.get("sample_packing") and not data.get("pad_to_sequence_len"): + LOG.warning( + "`pad_to_sequence_len: true` is recommended when using sample_packing" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_gas_bsz(cls, data): + if data.get("gradient_accumulation_steps") and data.get("batch_size"): + raise ValueError( + "please set only one of gradient_accumulation_steps or batch_size" + ) + return data + + @model_validator(mode="before") + @classmethod + def hint_eval_train_mbsz(cls, data): + if ( + data.get("eval_batch_size") + and data.get("micro_batch_size") + and data.get("eval_batch_size") != data.get("micro_batch_size") + ): + LOG.warning( + "eval_batch_size != micro_batch_size. This can lead to VRAM instability." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_push_ds_auth(cls, data): + if ( + data.get("push_dataset_to_hub") + and data.get("hf_use_auth_token") is not True + ): + raise ValueError( + "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" + ) + return data + + @model_validator(mode="after") + def check_falcon_fsdp(self): + if (self.base_model and "falcon" in self.base_model.lower()) and self.fsdp: + raise ValueError("FSDP is not supported for falcon models") + return self + + @model_validator(mode="after") + def check_mpt_checkpointing(self): + if ( + self.base_model and "mpt" in self.base_model.lower() + ) and self.gradient_checkpointing: + raise ValueError("gradient_checkpointing is not supported for MPT models") + return self + + @model_validator(mode="after") + def check_better_transformers(self): + if self.flash_optimum is True: + if self.adapter: + LOG.warning( + "BetterTransformers probably doesn't work with PEFT adapters" + ) + if self.fp16 or self.bf16: + raise ValueError("AMP is not supported with BetterTransformer") + if self.float16 is not True and self.bfloat16 is not True: + LOG.warning( + "You should probably set bfloat16 or float16 to true to " + "load the model in float16 for BetterTransformers" + ) + return self + + @model_validator(mode="after") + def check_adamw_optimizer_params(self): + if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and ( + not self.optimizer or "adamw" not in self.optimizer.value + ): + LOG.warning("adamw hyperparameters found, but no adamw optimizer set") + return self + + @model_validator(mode="before") + @classmethod + def check_saves(cls, data): + if ( + data.get("save_strategy") + and data.get("save_steps") + and data.get("save_strategy") != "steps" + ): + raise ValueError( + "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." + ) + if data.get("saves_per_epoch") and data.get("save_steps"): + raise ValueError( + "save_steps and saves_per_epoch are mutually exclusive and cannot be used together." 
+            )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_push_save(cls, data):
+        if data.get("hub_model_id") and not (
+            data.get("save_steps") or data.get("saves_per_epoch")
+        ):
+            LOG.warning(
+                "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
+            )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_evals(cls, data):
+        if (
+            data.get("evaluation_strategy")
+            and data.get("eval_steps")
+            and data.get("evaluation_strategy") != "steps"
+        ):
+            raise ValueError(
+                "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
+            )
+
+        if (
+            data.get("val_set_size") == 0
+            and (data.get("eval_steps") or data.get("evaluation_strategy"))
+            and not data.get("test_datasets")
+        ):
+            raise ValueError(
+                "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
+            )
+        if data.get("evals_per_epoch") and data.get("eval_steps"):
+            raise ValueError(
+                "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
+            )
+        if (
+            data.get("evals_per_epoch")
+            and data.get("evaluation_strategy")
+            and data.get("evaluation_strategy") != "steps"
+        ):
+            raise ValueError(
+                "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
+            )
+
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_eval_packing(cls, data):
+        if (
+            data.get("sample_packing")
+            and data.get("eval_table_size")
+            and data.get("eval_sample_packing") is not False
+        ):
+            raise ValueError(
+                "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
+            )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_warmup(cls, data):
+        if data.get("warmup_steps") and data.get("warmup_ratio"):
+            raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_neftune(cls, data):
+        if data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"):
+            data["neftune_noise_alpha"] = data["noisy_embedding_alpha"]
+            del data["noisy_embedding_alpha"]
+        elif data.get("noisy_embedding_alpha") and data.get("neftune_noise_alpha"):
+            raise ValueError(
+                "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
+            )
+        return data
+
+    @field_validator("neftune_noise_alpha")
+    @classmethod
+    def validate_neftune_noise_alpha(cls, neftune_noise_alpha):
+        if neftune_noise_alpha is not None and neftune_noise_alpha <= 0.0:
+            raise ValueError("neftune_noise_alpha must be > 0.0")
+        return neftune_noise_alpha
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_frozen(cls, data):
+        if (
+            data.get("adapter")
+            and data.get("peft_layers_to_transform")
+            and data.get("unfrozen_parameters")
+        ):
+            raise ValueError(
+                "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
+            )
+
+        return data
+
+    @model_validator(mode="after")
+    def check_fft_possible_bad_config(self):
+        if (
+            # pylint: disable=too-many-boolean-expressions
+            not (self.bf16 or self.bfloat16)
+            and (self.fp16 or self.float16)
+            and not self.adapter
+            and not self.flash_attention
+            and self.sample_packing
+        ):
+            LOG.warning(
+                "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
+            )
+            # ValueError: Attempting to unscale FP16 gradients.
+            # OR
+            # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
+        return self
+
+    @model_validator(mode="after")
+    def check_fused_lora(self):
+        if self.adapter in ["lora", "qlora"] and (
+            self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
+        ):
+            raise ValueError("Fused modules are not supported with LoRA/QLoRA")
+        return self
+
+    @model_validator(mode="after")
+    def hint_lora_8bit(self):
+        loftq = (
+            self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
+        )
+        if not self.load_in_8bit and self.adapter == "lora" and not loftq:
+            LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
+        return self
+
+    @model_validator(mode="after")
+    def check_early_stopping(self):
+        if self.early_stopping_patience:
+            if not self.save_steps or not self.eval_steps:
+                raise ValueError(
+                    "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
+                )
+            if self.save_steps % self.eval_steps != 0:
+                raise ValueError(
+                    "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
+                )
+        return self
+
+    @model_validator(mode="after")
+    def check_relora(self):
+        if self.relora_steps:
+            if self.adapter not in ("lora", "qlora"):
+                raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
+
+            if self.fsdp:
+                raise ValueError("fsdp not supported with ReLoRA")
+
+            if self.deepspeed:
+                raise ValueError("deepspeed not supported with ReLoRA")
+
+            if self.lr_scheduler == "one_cycle":
+                raise ValueError(
+                    "ReLoRA is not compatible with the one_cycle scheduler"
+                )
+
+            if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
+                raise ValueError("Fused modules are not supported with ReLoRA")
+        return self
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_mem_mismatch(cls, data):
+        if (
+            data.get("max_memory") is not None
+            and data.get("gpu_memory_limit") is not None
+        ):
+            raise ValueError(
+                "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
+            )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_use_reentrant_mismatch(cls, data):
+        if (
+            data.get("unfrozen_parameters")
+            and data.get("gradient_checkpointing_kwargs")
+            and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
+            is True
+        ):
+            # https://github.com/huggingface/transformers/issues/21381
+            raise ValueError(
+                "`use_reentrant` must be false when used with partially frozen model."
+            )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_val_w_test_datasets(cls, data):
+        if data.get("test_datasets") and data.get("val_set_size"):
+            raise ValueError(
+                "non-zero val_set_size should not be used with test_datasets configuration"
+            )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp_w_8bit_optimizer(cls, data):
+        if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
+            raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_causal_lm_evals(cls, data):
+        if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
+            raise ValueError(
+                "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
+            )
+
+        if data.get("eval_causal_lm_metrics"):
+            supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
+            if not isinstance(data.get("eval_causal_lm_metrics"), list):
+                raise ValueError("eval_causal_lm_metrics must be a list")
+            # only ["sacrebleu", "comet", "ter", "chrf"] supported
+            if set(data.get("eval_causal_lm_metrics")) - set(supported_metrics):
+                raise ValueError(
+                    f"eval_causal_lm_metrics must be one of {supported_metrics}"
+                )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_dataset_or_pretraining_dataset(cls, data):
+        if data.get("datasets") is None and data.get("pretraining_dataset") is None:
+            raise ValueError("either datasets or pretraining_dataset is required")
+        return data
+
+
+class AxolotlConfigWCapabilities(AxolotlInputConfig):
+    """wrapper to validate gpu capabilities with the configured options"""
+
+    capabilities: GPUCapabilities
+
+    @model_validator(mode="after")
+    def check_bf16(self):
+        if self.capabilities.bf16:
+            if not self.bf16 and not self.bfloat16:
+                LOG.info(
+                    "bf16 support detected, but not enabled for this configuration."
+                )
+        else:
+            if (
+                not self.merge_lora
+                and not self.is_preprocess
+                and (self.bf16 is True or self.bfloat16 is True)
+            ):
+                raise ValueError(
+                    "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
+                )
+        return self
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_sample_packing_w_sdpa_bf16(cls, data):
+        is_sm_90: bool = (
+            data["capabilities"]
+            and data["capabilities"].get("compute_capability") == "sm_90"
+        )
+        if (
+            data.get("sample_packing")
+            and data.get("sdp_attention")
+            and (data.get("bfloat16") or data.get("bf16"))
+            and not is_sm_90
+        ):
+            # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
+            LOG.warning(
+                "sample_packing & torch sdpa with bf16 is unsupported and may result in 0.0 loss. "
+                "This may work on H100s."
+ ) + + return data diff --git a/src/axolotl/utils/config/models/internals/__init__.py b/src/axolotl/utils/config/models/internals/__init__.py new file mode 100644 index 000000000..dd742caf4 --- /dev/null +++ b/src/axolotl/utils/config/models/internals/__init__.py @@ -0,0 +1,14 @@ +"""module for gpu capabilities""" +from typing import Optional + +from pydantic import BaseModel, Field + + +class GPUCapabilities(BaseModel): + """model to manage the gpu capabilities statically""" + + bf16: bool = Field(default=False) + fp8: bool = Field(default=False) + n_gpu: int = Field(default=1) + n_node: int = Field(default=1) + compute_capability: Optional[str] = Field(default=None) diff --git a/src/axolotl/utils/dict.py b/src/axolotl/utils/dict.py index 69567c604..409d088e6 100644 --- a/src/axolotl/utils/dict.py +++ b/src/axolotl/utils/dict.py @@ -12,4 +12,4 @@ class DictDefault(Dict): return None def __or__(self, other): - return DictDefault(super().__or__(other)) + return DictDefault(super().__ror__(other)) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c2006997d..c94908f3d 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -104,8 +104,8 @@ def load_model_config(cfg): ) raise err - if cfg.model_config: - for key, val in cfg.model_config.items(): + if cfg.model_config_overrides: + for key, val in cfg.model_config_overrides.items(): setattr(model_config, key, val) check_model_config(cfg, model_config) diff --git a/tests/test_dict.py b/tests/test_dict.py index 8367e7c2a..2007cb085 100644 --- a/tests/test_dict.py +++ b/tests/test_dict.py @@ -39,7 +39,9 @@ class DictDefaultTest(unittest.TestCase): ), "DictDefault should support in operator for existing keys in list" def test_dict_or_operator(self): - cfg = DictDefault( + cfg = DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"}) + + cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "key_a": {"key_b": "value_a"}, "key_c": "value_c", @@ -48,10 +50,6 @@ class DictDefaultTest(unittest.TestCase): } ) - cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation - {"key_a": {"key_b": "value_b"}, "key_f": "value_g"} - ) - assert ( cfg.key_a.key_b == "value_b" ), "DictDefault should support OR operator for existing nested keys" diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index cea39d0ad..cf662d95f 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -204,13 +204,13 @@ class TestPromptTokenizationStrategies(unittest.TestCase): # fmt: off # System message, multi-turn conversations mt_ids = tokenize(test_data['multi_turn_sys']) - assert decode(mt_ids) == ' [INST] lorem\nabc [/INST] ipsum [INST] 123 [/INST] sit' - assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + assert decode(mt_ids) == ' [INST] lorem\nabc [/INST] ipsum [INST] 123 [/INST] sit' + assert mt_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] # System message, single-turn conversations st_ids = tokenize(test_data['single_turn_sys']) - assert decode(st_ids) == ' [INST] lorem\nabc [/INST] ipsum' - assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] + assert decode(st_ids) == ' [INST] lorem\nabc [/INST] ipsum' + assert st_ids == [1, 518, 
25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] # No system message, single-turn ns_ids = tokenize(test_data['single_turn_no_sys']) diff --git a/tests/test_validation.py b/tests/test_validation.py index e5a74394c..4ec544351 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,20 +1,39 @@ +# pylint: disable=too-many-lines """Module for testing the validation module""" import logging import os -import unittest from typing import Optional import pytest -from transformers.utils import is_torch_bf16_gpu_available +from pydantic import ValidationError from axolotl.utils.config import validate_config +from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities from axolotl.utils.dict import DictDefault from axolotl.utils.models import check_model_config from axolotl.utils.wandb_ import setup_wandb_env_vars -class BaseValidation(unittest.TestCase): +@pytest.fixture(name="minimal_cfg") +def fixture_cfg(): + return DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + } + ], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + } + ) + + +class BaseValidation: """ Base validation module to setup the log capture """ @@ -27,14 +46,110 @@ class BaseValidation(unittest.TestCase): # pylint: disable=too-many-public-methods -class ValidationTest(BaseValidation): +class TestValidation(BaseValidation): """ Test the validation module """ + def test_datasets_min_length(self): + cfg = DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "datasets": [], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + } + ) + + with pytest.raises( + ValidationError, + match=r".*List should have at least 1 item after validation*", + ): + validate_config(cfg) + + def test_datasets_min_length_empty(self): + cfg = DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + } + ) + + with pytest.raises( + ValueError, match=r".*either datasets or pretraining_dataset is required*" + ): + validate_config(cfg) + + def test_pretrain_dataset_min_length(self): + cfg = DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "pretraining_dataset": [], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "max_steps": 100, + } + ) + + with pytest.raises( + ValidationError, + match=r".*List should have at least 1 item after validation*", + ): + validate_config(cfg) + + def test_valid_pretrain_dataset(self): + cfg = DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "pretraining_dataset": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + } + ], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "max_steps": 100, + } + ) + + validate_config(cfg) + + def test_valid_sft_dataset(self): + cfg = DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + } + ], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + } + ) + + validate_config(cfg) + def test_batch_size_unused_warning(self): cfg = DictDefault( { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "datasets": [ + { + "path": 
"mhenrichsen/alpaca_2k_test", + "type": "alpaca", + } + ], + "micro_batch_size": 4, "batch_size": 32, } ) @@ -43,104 +158,163 @@ class ValidationTest(BaseValidation): validate_config(cfg) assert "batch_size is not recommended" in self._caplog.records[0].message - def test_qlora(self): - base_cfg = DictDefault( + def test_batch_size_more_params(self): + cfg = DictDefault( { - "adapter": "qlora", + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + } + ], + "batch_size": 32, } ) - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "load_in_8bit": True, - } + with pytest.raises(ValueError, match=r".*At least two of*"): + validate_config(cfg) + + def test_qlora(self, minimal_cfg): + base_cfg = ( + DictDefault( + { + "adapter": "qlora", + } + ) + | minimal_cfg + ) + + cfg = ( + DictDefault( # pylint: disable=unsupported-binary-operation + { + "load_in_8bit": True, + } + ) + | base_cfg ) with pytest.raises(ValueError, match=r".*8bit.*"): validate_config(cfg) - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "gptq": True, - } + cfg = ( + DictDefault( # pylint: disable=unsupported-binary-operation + { + "gptq": True, + } + ) + | base_cfg ) with pytest.raises(ValueError, match=r".*gptq.*"): validate_config(cfg) - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "load_in_4bit": False, - } + cfg = ( + DictDefault( # pylint: disable=unsupported-binary-operation + { + "load_in_4bit": False, + } + ) + | base_cfg ) with pytest.raises(ValueError, match=r".*4bit.*"): validate_config(cfg) - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "load_in_4bit": True, - } + cfg = ( + DictDefault( # pylint: disable=unsupported-binary-operation + { + "load_in_4bit": True, + } + ) + | base_cfg ) validate_config(cfg) - def test_qlora_merge(self): - base_cfg = DictDefault( - { - "adapter": "qlora", - "merge_lora": True, - } + def test_qlora_merge(self, minimal_cfg): + base_cfg = ( + DictDefault( + { + "adapter": "qlora", + "merge_lora": True, + } + ) + | minimal_cfg ) - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "load_in_8bit": True, - } + cfg = ( + DictDefault( # pylint: disable=unsupported-binary-operation + { + "load_in_8bit": True, + } + ) + | base_cfg ) with pytest.raises(ValueError, match=r".*8bit.*"): validate_config(cfg) - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "gptq": True, - } + cfg = ( + DictDefault( # pylint: disable=unsupported-binary-operation + { + "gptq": True, + } + ) + | base_cfg ) with pytest.raises(ValueError, match=r".*gptq.*"): validate_config(cfg) - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "load_in_4bit": True, - } + cfg = ( + DictDefault( # pylint: disable=unsupported-binary-operation + { + "load_in_4bit": True, + } + ) + | base_cfg ) with pytest.raises(ValueError, match=r".*4bit.*"): validate_config(cfg) - def test_hf_use_auth_token(self): - cfg = DictDefault( - { - "push_dataset_to_hub": "namespace/repo", - } + def test_hf_use_auth_token(self, minimal_cfg): + cfg = ( + DictDefault( + { + "push_dataset_to_hub": "namespace/repo", + } + ) + | minimal_cfg ) with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"): validate_config(cfg) - cfg = DictDefault( - { - "push_dataset_to_hub": "namespace/repo", - "hf_use_auth_token": 
True, - } + cfg = ( + DictDefault( + { + "push_dataset_to_hub": "namespace/repo", + "hf_use_auth_token": True, + } + ) + | minimal_cfg ) validate_config(cfg) def test_gradient_accumulations_or_batch_size(self): cfg = DictDefault( { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + } + ], "gradient_accumulation_steps": 1, "batch_size": 1, } @@ -151,75 +325,75 @@ class ValidationTest(BaseValidation): ): validate_config(cfg) - cfg = DictDefault( - { - "batch_size": 1, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "gradient_accumulation_steps": 1, - } - ) - - validate_config(cfg) - - def test_falcon_fsdp(self): + def test_falcon_fsdp(self, minimal_cfg): regex_exp = r".*FSDP is not supported for falcon models.*" # Check for lower-case - cfg = DictDefault( - { - "base_model": "tiiuae/falcon-7b", - "fsdp": ["full_shard", "auto_wrap"], - } + cfg = ( + DictDefault( + { + "base_model": "tiiuae/falcon-7b", + "fsdp": ["full_shard", "auto_wrap"], + } + ) + | minimal_cfg ) with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) # Check for upper-case - cfg = DictDefault( - { - "base_model": "Falcon-7b", - "fsdp": ["full_shard", "auto_wrap"], - } + cfg = ( + DictDefault( + { + "base_model": "Falcon-7b", + "fsdp": ["full_shard", "auto_wrap"], + } + ) + | minimal_cfg ) with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) - cfg = DictDefault( - { - "base_model": "tiiuae/falcon-7b", - } + cfg = ( + DictDefault( + { + "base_model": "tiiuae/falcon-7b", + } + ) + | minimal_cfg ) validate_config(cfg) - def test_mpt_gradient_checkpointing(self): + def test_mpt_gradient_checkpointing(self, minimal_cfg): regex_exp = r".*gradient_checkpointing is not supported for MPT models*" # Check for lower-case - cfg = DictDefault( - { - "base_model": "mosaicml/mpt-7b", - "gradient_checkpointing": True, - } + cfg = ( + DictDefault( + { + "base_model": "mosaicml/mpt-7b", + "gradient_checkpointing": True, + } + ) + | minimal_cfg ) with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) - def test_flash_optimum(self): - cfg = DictDefault( - { - "flash_optimum": True, - "adapter": "lora", - } + def test_flash_optimum(self, minimal_cfg): + cfg = ( + DictDefault( + { + "flash_optimum": True, + "adapter": "lora", + "bf16": False, + } + ) + | minimal_cfg ) with self._caplog.at_level(logging.WARNING): @@ -230,10 +404,14 @@ class ValidationTest(BaseValidation): for record in self._caplog.records ) - cfg = DictDefault( - { - "flash_optimum": True, - } + cfg = ( + DictDefault( + { + "flash_optimum": True, + "bf16": False, + } + ) + | minimal_cfg ) with self._caplog.at_level(logging.WARNING): @@ -243,34 +421,43 @@ class ValidationTest(BaseValidation): for record in self._caplog.records ) - cfg = DictDefault( - { - "flash_optimum": True, - "fp16": True, - } + cfg = ( + DictDefault( + { + "flash_optimum": True, + "fp16": True, + } + ) + | minimal_cfg ) regex_exp = r".*AMP is not supported.*" with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) - cfg = DictDefault( - { - "flash_optimum": True, - "bf16": True, - } + cfg = ( + DictDefault( + { + "flash_optimum": True, + "bf16": True, + } + ) + | minimal_cfg ) regex_exp = r".*AMP is not supported.*" with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) - def test_adamw_hyperparams(self): - cfg = DictDefault( - { - "optimizer": None, - "adam_epsilon": 0.0001, - } + def 
test_adamw_hyperparams(self, minimal_cfg): + cfg = ( + DictDefault( + { + "optimizer": None, + "adam_epsilon": 0.0001, + } + ) + | minimal_cfg ) with self._caplog.at_level(logging.WARNING): @@ -281,11 +468,14 @@ class ValidationTest(BaseValidation): for record in self._caplog.records ) - cfg = DictDefault( - { - "optimizer": "adafactor", - "adam_beta1": 0.0001, - } + cfg = ( + DictDefault( + { + "optimizer": "adafactor", + "adam_beta1": 0.0001, + } + ) + | minimal_cfg ) with self._caplog.at_level(logging.WARNING): @@ -296,30 +486,39 @@ class ValidationTest(BaseValidation): for record in self._caplog.records ) - cfg = DictDefault( - { - "optimizer": "adamw_bnb_8bit", - "adam_beta1": 0.9, - "adam_beta2": 0.99, - "adam_epsilon": 0.0001, - } + cfg = ( + DictDefault( + { + "optimizer": "adamw_bnb_8bit", + "adam_beta1": 0.9, + "adam_beta2": 0.99, + "adam_epsilon": 0.0001, + } + ) + | minimal_cfg ) validate_config(cfg) - cfg = DictDefault( - { - "optimizer": "adafactor", - } + cfg = ( + DictDefault( + { + "optimizer": "adafactor", + } + ) + | minimal_cfg ) validate_config(cfg) - def test_deprecated_packing(self): - cfg = DictDefault( - { - "max_packed_sequence_len": 1024, - } + def test_deprecated_packing(self, minimal_cfg): + cfg = ( + DictDefault( + { + "max_packed_sequence_len": 1024, + } + ) + | minimal_cfg ) with pytest.raises( DeprecationWarning, @@ -327,12 +526,15 @@ class ValidationTest(BaseValidation): ): validate_config(cfg) - def test_packing(self): - cfg = DictDefault( - { - "sample_packing": True, - "pad_to_sequence_len": None, - } + def test_packing(self, minimal_cfg): + cfg = ( + DictDefault( + { + "sample_packing": True, + "pad_to_sequence_len": None, + } + ) + | minimal_cfg ) with self._caplog.at_level(logging.WARNING): validate_config(cfg) @@ -342,62 +544,79 @@ class ValidationTest(BaseValidation): for record in self._caplog.records ) - @pytest.mark.skipif( - is_torch_bf16_gpu_available(), - reason="test should only run on gpus w/o bf16 support", - ) - def test_merge_lora_no_bf16_fail(self): + def test_merge_lora_no_bf16_fail(self, minimal_cfg): """ This is assumed to be run on a CPU machine, so bf16 is not supported. """ - cfg = DictDefault( - { - "bf16": True, - } + cfg = ( + DictDefault( + { + "bf16": True, + "capabilities": {"bf16": False}, + } + ) + | minimal_cfg ) with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"): - validate_config(cfg) + AxolotlConfigWCapabilities(**cfg.to_dict()) - cfg = DictDefault( - { - "bf16": True, - "merge_lora": True, - } + cfg = ( + DictDefault( + { + "bf16": True, + "merge_lora": True, + "capabilities": {"bf16": False}, + } + ) + | minimal_cfg ) validate_config(cfg) - def test_sharegpt_deprecation(self): - cfg = DictDefault( - {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]} + def test_sharegpt_deprecation(self, minimal_cfg): + cfg = ( + DictDefault( + {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]} + ) + | minimal_cfg ) with self._caplog.at_level(logging.WARNING): - validate_config(cfg) + new_cfg = validate_config(cfg) assert any( "`type: sharegpt:chat` will soon be deprecated." 
                 in record.message
                 for record in self._caplog.records
             )
-            assert cfg.datasets[0].type == "sharegpt"
+            assert new_cfg.datasets[0].type == "sharegpt"

-        cfg = DictDefault(
-            {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
+        cfg = (
+            DictDefault(
+                {
+                    "datasets": [
+                        {"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}
+                    ]
+                }
+            )
+            | minimal_cfg
         )
         with self._caplog.at_level(logging.WARNING):
-            validate_config(cfg)
+            new_cfg = validate_config(cfg)
             assert any(
                 "`type: sharegpt_simple` will soon be deprecated."
                 in record.message
                 for record in self._caplog.records
             )
-            assert cfg.datasets[0].type == "sharegpt:load_role"
+            assert new_cfg.datasets[0].type == "sharegpt:load_role"

-    def test_no_conflict_save_strategy(self):
-        cfg = DictDefault(
-            {
-                "save_strategy": "epoch",
-                "save_steps": 10,
-            }
+    def test_no_conflict_save_strategy(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "save_strategy": "epoch",
+                    "save_steps": 10,
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -405,11 +624,14 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "save_strategy": "no",
-                "save_steps": 10,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "save_strategy": "no",
+                    "save_steps": 10,
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -417,45 +639,60 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "save_strategy": "steps",
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "save_strategy": "steps",
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "save_strategy": "steps",
-                "save_steps": 10,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "save_strategy": "steps",
+                    "save_steps": 10,
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "save_steps": 10,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "save_steps": 10,
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "save_strategy": "no",
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "save_strategy": "no",
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-    def test_no_conflict_eval_strategy(self):
-        cfg = DictDefault(
-            {
-                "evaluation_strategy": "epoch",
-                "eval_steps": 10,
-            }
+    def test_no_conflict_eval_strategy(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "evaluation_strategy": "epoch",
+                    "eval_steps": 10,
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -463,11 +700,14 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "evaluation_strategy": "no",
-                "eval_steps": 10,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "evaluation_strategy": "no",
+                    "eval_steps": 10,
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -475,44 +715,59 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "evaluation_strategy": "steps",
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "evaluation_strategy": "steps",
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "evaluation_strategy": "steps",
-                "eval_steps": 10,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "evaluation_strategy": "steps",
+                    "eval_steps": 10,
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "eval_steps": 10,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "eval_steps": 10,
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "evaluation_strategy": "no",
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "evaluation_strategy": "no",
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "evaluation_strategy": "epoch",
-                "val_set_size": 0,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "evaluation_strategy": "epoch",
+                    "val_set_size": 0,
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -521,11 +776,14 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "eval_steps": 10,
-                "val_set_size": 0,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "eval_steps": 10,
+                    "val_set_size": 0,
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -534,38 +792,50 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "val_set_size": 0,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "val_set_size": 0,
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "eval_steps": 10,
-                "val_set_size": 0.01,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "eval_steps": 10,
+                    "val_set_size": 0.01,
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "evaluation_strategy": "epoch",
-                "val_set_size": 0.01,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "evaluation_strategy": "epoch",
+                    "val_set_size": 0.01,
+                }
+            )
+            | minimal_cfg
        )
         validate_config(cfg)
-    def test_eval_table_size_conflict_eval_packing(self):
-        cfg = DictDefault(
-            {
-                "sample_packing": True,
-                "eval_table_size": 100,
-            }
+    def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "sample_packing": True,
+                    "eval_table_size": 100,
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -573,39 +843,51 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "sample_packing": True,
-                "eval_sample_packing": False,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "sample_packing": True,
+                    "eval_sample_packing": False,
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "sample_packing": False,
-                "eval_table_size": 100,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "sample_packing": False,
+                    "eval_table_size": 100,
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "sample_packing": True,
-                "eval_table_size": 100,
-                "eval_sample_packing": False,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "sample_packing": True,
+                    "eval_table_size": 100,
+                    "eval_sample_packing": False,
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-    def test_load_in_x_bit_without_adapter(self):
-        cfg = DictDefault(
-            {
-                "load_in_4bit": True,
-            }
+    def test_load_in_x_bit_without_adapter(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "load_in_4bit": True,
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -614,10 +896,13 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "load_in_8bit": True,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "load_in_8bit": True,
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -626,30 +911,39 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "load_in_4bit": True,
-                "adapter": "qlora",
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "load_in_4bit": True,
+                    "adapter": "qlora",
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "load_in_8bit": True,
-                "adapter": "lora",
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "load_in_8bit": True,
+                    "adapter": "lora",
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-    def test_warmup_step_no_conflict(self):
-        cfg = DictDefault(
-            {
-                "warmup_steps": 10,
-                "warmup_ratio": 0.1,
-            }
+    def test_warmup_step_no_conflict(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "warmup_steps": 10,
+                    "warmup_ratio": 0.1,
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -658,29 +952,40 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "warmup_steps": 10,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "warmup_steps": 10,
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-        cfg = DictDefault(
-            {
-                "warmup_ratio": 0.1,
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "warmup_ratio": 0.1,
+                }
+            )
+            | minimal_cfg
         )
         validate_config(cfg)
-    def test_unfrozen_parameters_w_peft_layers_to_transform(self):
-        cfg = DictDefault(
-            {
-                "adapter": "lora",
-                "unfrozen_parameters": ["model.layers.2[0-9]+.block_sparse_moe.gate.*"],
-                "peft_layers_to_transform": [0, 1],
-            }
+    def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "adapter": "lora",
+                    "unfrozen_parameters": [
+                        "model.layers.2[0-9]+.block_sparse_moe.gate.*"
+                    ],
+                    "peft_layers_to_transform": [0, 1],
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -689,8 +994,8 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)
-    def test_hub_model_id_save_value_warns(self):
-        cfg = DictDefault({"hub_model_id": "test"})
+    def test_hub_model_id_save_value_warns(self, minimal_cfg):
+        cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
         with self._caplog.at_level(logging.WARNING):
             validate_config(cfg)
@@ -698,22 +1003,25 @@ class ValidationTest(BaseValidation):
                 "set without any models being saved" in self._caplog.records[0].message
             )
-    def test_hub_model_id_save_value(self):
-        cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})
+    def test_hub_model_id_save_value(self, minimal_cfg):
+        cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
         with self._caplog.at_level(logging.WARNING):
             validate_config(cfg)
             assert len(self._caplog.records) == 0
-class ValidationCheckModelConfig(BaseValidation):
+class TestValidationCheckModelConfig(BaseValidation):
     """
     Test the validation for the config when the model config is available
     """
-    def test_llama_add_tokens_adapter(self):
-        cfg = DictDefault(
-            {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+    def test_llama_add_tokens_adapter(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+            )
+            | minimal_cfg
         )
         model_config = DictDefault({"model_type": "llama"})
         with pytest.raises(
@@ -723,13 +1031,16 @@ class ValidationCheckModelConfig(BaseValidation):
         ):
            check_model_config(cfg, model_config)
-        cfg = DictDefault(
-            {
-                "adapter": "qlora",
-                "load_in_4bit": True,
-                "tokens": ["<|imstart|>"],
-                "lora_modules_to_save": ["embed_tokens"],
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "adapter": "qlora",
+                    "load_in_4bit": True,
+                    "tokens": ["<|imstart|>"],
+                    "lora_modules_to_save": ["embed_tokens"],
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -738,20 +1049,26 @@ class ValidationCheckModelConfig(BaseValidation):
        ):
            check_model_config(cfg, model_config)
-        cfg = DictDefault(
-            {
-                "adapter": "qlora",
-                "load_in_4bit": True,
-                "tokens": ["<|imstart|>"],
-                "lora_modules_to_save": ["embed_tokens", "lm_head"],
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "adapter": "qlora",
+                    "load_in_4bit": True,
+                    "tokens": ["<|imstart|>"],
+                    "lora_modules_to_save": ["embed_tokens", "lm_head"],
+                }
+            )
+            | minimal_cfg
         )
         check_model_config(cfg, model_config)
-    def test_phi_add_tokens_adapter(self):
-        cfg = DictDefault(
-            {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+    def test_phi_add_tokens_adapter(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+            )
+            | minimal_cfg
         )
         model_config = DictDefault({"model_type": "phi"})
         with pytest.raises(
@@ -761,13 +1078,16 @@ class ValidationCheckModelConfig(BaseValidation):
        ):
            check_model_config(cfg, model_config)
-        cfg = DictDefault(
-            {
-                "adapter": "qlora",
-                "load_in_4bit": True,
-                "tokens": ["<|imstart|>"],
-                "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "adapter": "qlora",
+                    "load_in_4bit": True,
+                    "tokens": ["<|imstart|>"],
+                    "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
+                }
+            )
+            | minimal_cfg
         )
         with pytest.raises(
@@ -776,66 +1096,78 @@ class ValidationCheckModelConfig(BaseValidation):
        ):
            check_model_config(cfg, model_config)
-        cfg = DictDefault(
-            {
-                "adapter": "qlora",
-                "load_in_4bit": True,
-                "tokens": ["<|imstart|>"],
-                "lora_modules_to_save": ["embed_tokens", "lm_head"],
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "adapter": "qlora",
+                    "load_in_4bit": True,
+                    "tokens": ["<|imstart|>"],
+                    "lora_modules_to_save": ["embed_tokens", "lm_head"],
+                }
+            )
+            | minimal_cfg
         )
         check_model_config(cfg, model_config)
-class ValidationWandbTest(BaseValidation):
+class TestValidationWandb(BaseValidation):
     """
     Validation test for wandb
     """
-    def test_wandb_set_run_id_to_name(self):
-        cfg = DictDefault(
-            {
-                "wandb_run_id": "foo",
-            }
+    def test_wandb_set_run_id_to_name(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "wandb_run_id": "foo",
+                }
+            )
+            | minimal_cfg
         )
         with self._caplog.at_level(logging.WARNING):
-            validate_config(cfg)
+            new_cfg = validate_config(cfg)
             assert any(
                 "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
                 in record.message
                 for record in self._caplog.records
            )
-        assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
+        assert new_cfg.wandb_name == "foo" and new_cfg.wandb_run_id == "foo"

-        cfg = DictDefault(
-            {
-                "wandb_name": "foo",
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "wandb_name": "foo",
+                }
+            )
+            | minimal_cfg
         )
-        validate_config(cfg)
+        new_cfg = validate_config(cfg)

-        assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
+        assert new_cfg.wandb_name == "foo" and new_cfg.wandb_run_id is None

-    def test_wandb_sets_env(self):
-        cfg = DictDefault(
-            {
-                "wandb_project": "foo",
-                "wandb_name": "bar",
-                "wandb_run_id": "bat",
-                "wandb_entity": "baz",
-                "wandb_mode": "online",
-                "wandb_watch": "false",
-                "wandb_log_model": "checkpoint",
-            }
+    def test_wandb_sets_env(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "wandb_project": "foo",
+                    "wandb_name": "bar",
+                    "wandb_run_id": "bat",
+                    "wandb_entity": "baz",
+                    "wandb_mode": "online",
+                    "wandb_watch": "false",
+                    "wandb_log_model": "checkpoint",
+                }
+            )
+            | minimal_cfg
         )
-        validate_config(cfg)
+        new_cfg = validate_config(cfg)

-        setup_wandb_env_vars(cfg)
+        setup_wandb_env_vars(new_cfg)
         assert os.environ.get("WANDB_PROJECT", "") == "foo"
         assert os.environ.get("WANDB_NAME", "") == "bar"
@@ -855,24 +1187,27 @@ class ValidationWandbTest(BaseValidation):
         os.environ.pop("WANDB_LOG_MODEL", None)
         os.environ.pop("WANDB_DISABLED", None)
-    def test_wandb_set_disabled(self):
-        cfg = DictDefault({})
+    def test_wandb_set_disabled(self, minimal_cfg):
+        cfg = DictDefault({}) | minimal_cfg

-        validate_config(cfg)
+        new_cfg = validate_config(cfg)

-        setup_wandb_env_vars(cfg)
+        setup_wandb_env_vars(new_cfg)
         assert os.environ.get("WANDB_DISABLED", "") == "true"
-        cfg = DictDefault(
-            {
-                "wandb_project": "foo",
-            }
+        cfg = (
+            DictDefault(
+                {
+                    "wandb_project": "foo",
+                }
+            )
+            | minimal_cfg
         )
-        validate_config(cfg)
+        new_cfg = validate_config(cfg)

-        setup_wandb_env_vars(cfg)
+        setup_wandb_env_vars(new_cfg)
         assert os.environ.get("WANDB_DISABLED", "") != "true"
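Note on the test refactor above: each case now layers its overrides onto a shared minimal_cfg fixture with DictDefault's | operator and asserts against the object returned by validate_config rather than the dict that was passed in. The sketch below illustrates that pattern in isolation. It is a minimal sketch, not code from this PR: the fixture body shown here (the base_model, learning_rate, and datasets values) is an assumption for illustration only; the real fixture is defined elsewhere in tests/test_validation.py.

import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


@pytest.fixture(name="minimal_cfg")
def fixture_minimal_cfg():
    # Hypothetical minimal, valid base config that individual tests extend.
    # The concrete values below are assumptions, not the fixture from this PR.
    return DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # assumed value
            "learning_rate": 0.000001,  # assumed value
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",  # assumed value
                    "type": "alpaca",
                }
            ],
        }
    )


def test_warmup_ratio_alone_is_valid(minimal_cfg):
    # Per-test overrides are merged onto the shared base config, mirroring the
    # `DictDefault({...}) | minimal_cfg` pattern used throughout the diff above.
    cfg = DictDefault({"warmup_ratio": 0.1}) | minimal_cfg

    # Assertions are made against the value returned by validate_config
    # (new_cfg), as in the refactored tests, not against the input dict.
    new_cfg = validate_config(cfg)
    assert new_cfg.warmup_ratio == 0.1

The usage mirrors test_warmup_step_no_conflict above, where warmup_ratio on its own is expected to validate without raising.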