diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index cd7fd9bee..cca8b92a1 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -126,12 +126,8 @@ class DeprecatedParameters(BaseModel): class RemappedParameters(BaseModel): """parameters that have been remapped to other names""" - overrides_of_model_config: Optional[Dict[str, Any]] = Field( - default=None, alias="model_config" - ) - overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field( - default=None, alias="model_kwargs" - ) + overrides_of_model_config: Optional[Dict[str, Any]] = Field(default=None, alias="model_config") + overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(default=None, alias="model_kwargs") type_of_model: Optional[str] = Field(default=None, alias="model_type") revision_of_model: Optional[str] = Field(default=None, alias="model_revision") @@ -200,12 +196,8 @@ class SFTDataset(BaseModel): field_human: Optional[str] = None field_model: Optional[str] = None field_messages: Optional[str] = None - message_field_role: Optional[ - str - ] = None # deprecated, use message_property_mappings - message_field_content: Optional[ - str - ] = None # deprecated, use message_property_mappings + message_field_role: Optional[str] = None # deprecated, use message_property_mappings + message_field_content: Optional[str] = None # deprecated, use message_property_mappings message_property_mappings: Optional[Dict[str, str]] = None message_field_training: Optional[str] = None message_field_training_detail: Optional[str] = None @@ -235,12 +227,8 @@ class SFTDataset(BaseModel): data["chat_template"] = ChatTemplate.tokenizer_default # if chat_template is set to jinja, chat_template_jinja is required - if data.get("chat_template") == ChatTemplate.jinja and not data.get( - "chat_template_jinja" - ): - raise ValueError( - "chat_template_jinja is required when chat_template is set to jinja" - ) + if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"): + raise ValueError("chat_template_jinja is required when chat_template is set to jinja") # If chat_template_jinja is set, set chat_template to jinja if data.get("chat_template_jinja") and not data.get("chat_template"): @@ -312,9 +300,7 @@ DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedData class LoftQConfig(BaseModel): """LoftQ configuration subset""" - loftq_bits: int = Field( - default=4, json_schema_extra={"description": "Quantization bits for LoftQ"} - ) + loftq_bits: int = Field(default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}) # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"}) @@ -359,9 +345,7 @@ class LoraConfig(BaseModel): qlora_sharded_model_loading: Optional[bool] = Field( default=False, - json_schema_extra={ - "description": "load qlora model in sharded format for FSDP using answer.ai technique." - }, + json_schema_extra={"description": "load qlora model in sharded format for FSDP using answer.ai technique."}, ) lora_on_cpu: Optional[bool] = None gptq: Optional[bool] = None @@ -369,15 +353,11 @@ class LoraConfig(BaseModel): loraplus_lr_ratio: Optional[float] = Field( default=None, - json_schema_extra={ - "description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4." - }, + json_schema_extra={"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."}, ) loraplus_lr_embedding: Optional[float] = Field( default=1e-6, - json_schema_extra={ - "description": "loraplus learning rate for lora embedding layers." - }, + json_schema_extra={"description": "loraplus learning rate for lora embedding layers."}, ) merge_lora: Optional[bool] = None @@ -485,15 +465,11 @@ class HyperparametersConfig(BaseModel): ) batch_size: Optional[int] = Field( default=None, - json_schema_extra={ - "description": "Total batch size, we do not recommended setting this manually" - }, + json_schema_extra={"description": "Total batch size, we do not recommended setting this manually"}, ) eval_batch_size: Optional[int] = Field( default=None, - json_schema_extra={ - "description": "per gpu micro batch size for evals, defaults to value of micro_batch_size" - }, + json_schema_extra={"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"}, ) auto_find_batch_size: Optional[bool] = None @@ -505,9 +481,7 @@ class HyperparametersConfig(BaseModel): embedding_lr: Optional[float] = None embedding_lr_scale: Optional[float] = None weight_decay: Optional[float] = 0.0 - optimizer: Optional[ - Union[OptimizerNames, CustomSupportedOptimizers] - ] = OptimizerNames.ADAMW_HF + optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = OptimizerNames.ADAMW_HF optim_args: Optional[Union[str, Dict[str, Any]]] = Field( default=None, json_schema_extra={"description": "Optional arguments to supply to optimizer."}, @@ -519,9 +493,7 @@ class HyperparametersConfig(BaseModel): }, ) torchdistx_path: Optional[str] = None - lr_scheduler: Optional[ - Union[SchedulerType, Literal["one_cycle"], Literal["rex"]] - ] = SchedulerType.COSINE + lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]] = SchedulerType.COSINE lr_scheduler_kwargs: Optional[Dict[str, Any]] = None lr_quadratic_warmup: Optional[bool] = None cosine_min_lr_ratio: Optional[float] = None @@ -645,21 +617,15 @@ class RayConfig(BaseModel): use_ray: bool = Field(default=False) ray_run_name: Optional[str] = Field( default=None, - json_schema_extra={ - "help": "The training results will be saved at `saves/ray_run_name`." - }, + json_schema_extra={"help": "The training results will be saved at `saves/ray_run_name`."}, ) ray_num_workers: int = Field( default=1, - json_schema_extra={ - "help": "The number of workers for Ray training. Default is 1 worker." - }, + json_schema_extra={"help": "The number of workers for Ray training. Default is 1 worker."}, ) resources_per_worker: dict = Field( default_factory=lambda: {"GPU": 1}, - json_schema_extra={ - "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker." - }, + json_schema_extra={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, ) @@ -699,9 +665,9 @@ class AxolotlInputConfig( reward_model: Optional[bool] = None process_reward_model: Optional[bool] = None num_labels: Optional[int] = None - dpo_use_weighting: Optional[ - bool - ] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer. + dpo_use_weighting: Optional[bool] = ( + None # whether to use weighting in DPO trainer. If none, default is false in the trainer. + ) dpo_use_logits_to_keep: Optional[bool] = None datasets: Optional[ @@ -723,9 +689,7 @@ class AxolotlInputConfig( dataset_shard_idx: Optional[int] = None skip_prepare_dataset: Optional[bool] = False - pretraining_dataset: Optional[ - Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)] - ] = Field( + pretraining_dataset: Optional[Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]] = Field( default=None, json_schema_extra={"description": "streaming dataset to use for pretraining"}, ) @@ -780,9 +744,7 @@ class AxolotlInputConfig( # torch_dtype: Optional[torch.dtype] - gradient_checkpointing: Optional[ - Union[Literal["unsloth", "offload"], bool] - ] = Field(default=False) + gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = Field(default=False) gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None unfrozen_parameters: Optional[List[str]] = None @@ -849,9 +811,7 @@ class AxolotlInputConfig( deepspeed: Optional[Union[str, Dict[str, Any]]] = None fsdp: Optional[List[str]] = None fsdp_config: Optional[Dict[str, Any]] = None - fsdp_final_state_dict_type: Optional[ - Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] - ] = None + fsdp_final_state_dict_type: Optional[Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]] = None val_set_size: Optional[float] = Field(default=0.0) @@ -861,9 +821,7 @@ class AxolotlInputConfig( torch_compile: Optional[Union[Literal["auto"], bool]] = None torch_compile_backend: Optional[str] = None - torch_compile_mode: Optional[ - Literal["default", "reduce-overhead", "max-autotune"] - ] = None + torch_compile_mode: Optional[Literal["default", "reduce-overhead", "max-autotune"]] = None max_steps: Optional[int] = None warmup_steps: Optional[int] = None @@ -894,9 +852,7 @@ class AxolotlInputConfig( kto_undesirable_weight: Optional[float] = None rl_beta: Optional[float] = None - max_memory: Optional[ - Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]] - ] = None + max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = None gpu_memory_limit: Optional[Union[int, str]] = None low_cpu_mem_usage: Optional[bool] = None @@ -932,11 +888,7 @@ class AxolotlInputConfig( def deprecate_sharegpt_datasets(cls, datasets): for _, ds_cfg in enumerate(datasets): # Handle both dict and pydantic model cases - ds_type = ( - ds_cfg.get("type") - if isinstance(ds_cfg, dict) - else getattr(ds_cfg, "type", None) - ) + ds_type = ds_cfg.get("type") if isinstance(ds_cfg, dict) else getattr(ds_cfg, "type", None) if not ds_type: continue @@ -945,16 +897,12 @@ class AxolotlInputConfig( continue if isinstance(ds_type, str) and ds_type.startswith("sharegpt"): - raise ValueError( - "`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead." - ) + raise ValueError("`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead.") return datasets @field_serializer("datasets") - def datasets_serializer( - self, ds_configs: Optional[List[DatasetConfig]] - ) -> Optional[List[Dict[str, Any]]]: + def datasets_serializer(self, ds_configs: Optional[List[DatasetConfig]]) -> Optional[List[Dict[str, Any]]]: if ds_configs: return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs] return None @@ -1020,9 +968,7 @@ class AxolotlInputConfig( @classmethod def check_sample_packing_w_xformers(cls, data): if data.get("sample_packing") and data.get("xformers_attention"): - raise ValueError( - "sample_packing not compatible with xformers_attention. Use flash_attention" - ) + raise ValueError("sample_packing not compatible with xformers_attention. Use flash_attention") return data @@ -1030,12 +976,8 @@ class AxolotlInputConfig( @classmethod def check_chat_template_config(cls, data): # if chat_template is set to jinja, chat_template_jinja is required - if data.get("chat_template") == ChatTemplate.jinja and not data.get( - "chat_template_jinja" - ): - raise ValueError( - "chat_template_jinja is required when chat_template is set to jinja" - ) + if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"): + raise ValueError("chat_template_jinja is required when chat_template is set to jinja") # If chat_template_jinja is set, set chat_template to jinja if data.get("chat_template_jinja") and not data.get("chat_template"): @@ -1046,14 +988,8 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod def check_sample_packing_wo_flash(cls, data): - if ( - data.get("sample_packing") - and not data.get("flash_attention") - and not data.get("sdp_attention") - ): - LOG.warning( - "sample_packing without flash_attention or sdp_attention does not handle cross-attention." - ) + if data.get("sample_packing") and not data.get("flash_attention") and not data.get("sdp_attention"): + LOG.warning("sample_packing without flash_attention or sdp_attention does not handle cross-attention.") return data @@ -1092,18 +1028,14 @@ class AxolotlInputConfig( @classmethod def hint_sample_packing_padding(cls, data): if data.get("sample_packing") and not data.get("pad_to_sequence_len"): - LOG.warning( - "`pad_to_sequence_len: true` is recommended when using sample_packing" - ) + LOG.warning("`pad_to_sequence_len: true` is recommended when using sample_packing") return data @model_validator(mode="before") @classmethod def hint_reward_model_pad(cls, data): if data.get("reward_model") and not data.get("pad_to_sequence_len"): - LOG.warning( - "`pad_to_sequence_len: true` is recommended when using reward_model" - ) + LOG.warning("`pad_to_sequence_len: true` is recommended when using reward_model") if data.get("pad_to_sequence_len") is None: data["pad_to_sequence_len"] = True return data @@ -1112,9 +1044,7 @@ class AxolotlInputConfig( @classmethod def check_gas_bsz(cls, data): if data.get("gradient_accumulation_steps") and data.get("batch_size"): - raise ValueError( - "please set only one of gradient_accumulation_steps or batch_size" - ) + raise ValueError("please set only one of gradient_accumulation_steps or batch_size") return data @model_validator(mode="before") @@ -1125,21 +1055,14 @@ class AxolotlInputConfig( and data.get("micro_batch_size") and data.get("eval_batch_size") != data.get("micro_batch_size") ): - LOG.warning( - "eval_batch_size != micro_batch_size. This can lead to VRAM instability." - ) + LOG.warning("eval_batch_size != micro_batch_size. This can lead to VRAM instability.") return data @model_validator(mode="before") @classmethod def check_push_ds_auth(cls, data): - if ( - data.get("push_dataset_to_hub") - and data.get("hf_use_auth_token") is not True - ): - raise ValueError( - "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" - ) + if data.get("push_dataset_to_hub") and data.get("hf_use_auth_token") is not True: + raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub") return data @model_validator(mode="after") @@ -1150,18 +1073,14 @@ class AxolotlInputConfig( @model_validator(mode="after") def check_mpt_checkpointing(self): - if ( - self.base_model and "mpt" in self.base_model.lower() - ) and self.gradient_checkpointing: + if (self.base_model and "mpt" in self.base_model.lower()) and self.gradient_checkpointing: raise ValueError("gradient_checkpointing is not supported for MPT models") return self @model_validator(mode="after") def check_offload_grad_checkpointing(self): if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth": - LOG.warning( - "`unsloth` is deprecated for gradient_checkpointing, use `offload`" - ) + LOG.warning("`unsloth` is deprecated for gradient_checkpointing, use `offload`") self.gradient_checkpointing = "offload" return self @@ -1169,9 +1088,7 @@ class AxolotlInputConfig( def check_better_transformers(self): if self.flash_optimum is True: if self.adapter: - LOG.warning( - "BetterTransformers probably doesn't work with PEFT adapters" - ) + LOG.warning("BetterTransformers probably doesn't work with PEFT adapters") if self.fp16 or self.bf16: raise ValueError("AMP is not supported with BetterTransformer") if self.float16 is not True and self.bfloat16 is not True: @@ -1199,39 +1116,25 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod def check_saves(cls, data): - if ( - data.get("save_strategy") - and data.get("save_steps") - and data.get("save_strategy") != "steps" - ): + if data.get("save_strategy") and data.get("save_steps") and data.get("save_strategy") != "steps": raise ValueError( "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." ) if data.get("saves_per_epoch") and data.get("save_steps"): - raise ValueError( - "save_steps and saves_per_epoch are mutually exclusive and cannot be used together." - ) + raise ValueError("save_steps and saves_per_epoch are mutually exclusive and cannot be used together.") return data @model_validator(mode="before") @classmethod def check_push_save(cls, data): - if data.get("hub_model_id") and ( - data.get("save_strategy") not in ["steps", "epoch", None] - ): - LOG.warning( - "hub_model_id is set without any models being saved. To save a model, set save_strategy." - ) + if data.get("hub_model_id") and (data.get("save_strategy") not in ["steps", "epoch", None]): + LOG.warning("hub_model_id is set without any models being saved. To save a model, set save_strategy.") return data @model_validator(mode="before") @classmethod def check_evals(cls, data): - if ( - data.get("eval_strategy") - and data.get("eval_steps") - and data.get("eval_strategy") != "steps" - ): + if data.get("eval_strategy") and data.get("eval_steps") and data.get("eval_strategy") != "steps": raise ValueError( "eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps." ) @@ -1241,41 +1144,21 @@ class AxolotlInputConfig( and (data.get("eval_steps") or data.get("eval_strategy")) and not data.get("test_datasets") ): - raise ValueError( - "eval_steps and eval_strategy are not supported with val_set_size == 0" - ) + raise ValueError("eval_steps and eval_strategy are not supported with val_set_size == 0") if data.get("evals_per_epoch") and data.get("eval_steps"): - raise ValueError( - "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together." - ) - if ( - data.get("evals_per_epoch") - and data.get("eval_strategy") - and data.get("eval_strategy") != "steps" - ): - raise ValueError( - "eval_strategy must be empty or set to `steps` when used with evals_per_epoch." - ) + raise ValueError("eval_steps and evals_per_epoch are mutually exclusive and cannot be used together.") + if data.get("evals_per_epoch") and data.get("eval_strategy") and data.get("eval_strategy") != "steps": + raise ValueError("eval_strategy must be empty or set to `steps` when used with evals_per_epoch.") - if data.get("do_bench_eval") and not ( - data.get("evals_per_epoch") or data.get("eval_steps") - ): - raise ValueError( - "do_bench_eval requires evals_per_epoch or eval_steps to be set." - ) + if data.get("do_bench_eval") and not (data.get("evals_per_epoch") or data.get("eval_steps")): + raise ValueError("do_bench_eval requires evals_per_epoch or eval_steps to be set.") return data @model_validator(mode="before") @classmethod def check_test_datasets_bench(cls, data): - if ( - data.get("do_bench_eval") - and not data.get("test_datasets") - and not data.get("val_set_size") - ): - LOG.warning( - "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset." - ) + if data.get("do_bench_eval") and not data.get("test_datasets") and not data.get("val_set_size"): + LOG.warning("`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset.") data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}] return data @@ -1284,22 +1167,12 @@ class AxolotlInputConfig( def check_eval_packing(cls, data): # TODO also should check test_datasets and val_set_size as we can skip # if there are no eval datasets/splits - if ( - data.get("sample_packing") - and data.get("eval_table_size") - and data.get("eval_sample_packing") is not False - ): + if data.get("sample_packing") and data.get("eval_table_size") and data.get("eval_sample_packing") is not False: raise ValueError( "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false." ) - if ( - data.get("sample_packing") - and data.get("eval_sample_packing") is None - and not data.get("eval_table_size") - ): - LOG.info( - "explicitly setting `eval_sample_packing` to match `sample_packing`" - ) + if data.get("sample_packing") and data.get("eval_sample_packing") is None and not data.get("eval_table_size"): + LOG.info("explicitly setting `eval_sample_packing` to match `sample_packing`") data["eval_sample_packing"] = True if ( @@ -1319,9 +1192,7 @@ class AxolotlInputConfig( def check_mm_prepare(cls, data): if data.get("skip_prepare_dataset"): if data.get("remove_unused_columns") is None: - LOG.info( - "setting `remove_unused_columns: false` for skip_prepare_dataset" - ) + LOG.info("setting `remove_unused_columns: false` for skip_prepare_dataset") data["remove_unused_columns"] = False return data @@ -1362,19 +1233,13 @@ class AxolotlInputConfig( @model_validator(mode="after") def check_simpo_warmup(self): if self.rl == "simpo" and self.warmup_ratio: - raise ValueError( - "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead" - ) + raise ValueError("warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead") return self @model_validator(mode="before") @classmethod def check_frozen(cls, data): - if ( - data.get("adapter") - and data.get("peft_layers_to_transform") - and data.get("unfrozen_parameters") - ): + if data.get("adapter") and data.get("peft_layers_to_transform") and data.get("unfrozen_parameters"): raise ValueError( "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior." ) @@ -1385,9 +1250,7 @@ class AxolotlInputConfig( @classmethod def check_peft_layers_pattern(cls, data): if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"): - raise ValueError( - "peft_layers_pattern requires peft_layers_to_transform to be set" - ) + raise ValueError("peft_layers_pattern requires peft_layers_to_transform to be set") return data @model_validator(mode="after") @@ -1410,17 +1273,13 @@ class AxolotlInputConfig( @model_validator(mode="after") def check_fused_lora(self): - if self.adapter in ["lora", "qlora"] and ( - self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp - ): + if self.adapter in ["lora", "qlora"] and (self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp): raise ValueError("Fused modules are not supported with LoRA/QLoRA") return self @model_validator(mode="after") def hint_lora_8bit(self): - loftq = ( - self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits - ) + loftq = self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits if not self.load_in_8bit and self.adapter == "lora" and not loftq: LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") return self @@ -1433,9 +1292,7 @@ class AxolotlInputConfig( "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps." ) if self.save_steps % self.eval_steps != 0: - raise ValueError( - "`early_stopping_patience` requires that eval_steps should evenly divide save_steps." - ) + raise ValueError("`early_stopping_patience` requires that eval_steps should evenly divide save_steps.") return self @model_validator(mode="after") @@ -1451,9 +1308,7 @@ class AxolotlInputConfig( raise ValueError("deepspeed not supported with ReLoRA") if self.lr_scheduler == "one_cycle": - raise ValueError( - "ReLoRA is not compatible with the one_cycle scheduler" - ) + raise ValueError("ReLoRA is not compatible with the one_cycle scheduler") if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp: raise ValueError("Fused modules are not supported with ReLoRA") @@ -1462,13 +1317,8 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod def check_mem_mismatch(cls, data): - if ( - data.get("max_memory") is not None - and data.get("gpu_memory_limit") is not None - ): - raise ValueError( - "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together." - ) + if data.get("max_memory") is not None and data.get("gpu_memory_limit") is not None: + raise ValueError("max_memory and gpu_memory_limit are mutually exclusive and cannot be used together.") return data @model_validator(mode="before") @@ -1477,13 +1327,10 @@ class AxolotlInputConfig( if ( data.get("unfrozen_parameters") and data.get("gradient_checkpointing_kwargs") - and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") - is True + and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is True ): # https://github.com/huggingface/transformers/issues/21381 - raise ValueError( - "`use_reentrant` must be false when used with partially frozen model." - ) + raise ValueError("`use_reentrant` must be false when used with partially frozen model.") return data @model_validator(mode="before") @@ -1492,8 +1339,7 @@ class AxolotlInputConfig( if ( data.get("adapter") == "qlora" and data.get("gradient_checkpointing_kwargs", {}) - and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") - is False + and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is False and data.get("deepspeed", "") is not None and "zero3" in data.get("deepspeed", "") ): @@ -1510,21 +1356,14 @@ class AxolotlInputConfig( @classmethod def check_val_w_test_datasets(cls, data): if data.get("test_datasets") and data.get("val_set_size"): - raise ValueError( - "non-zero val_set_size should not be used with test_datasets configuration" - ) + raise ValueError("non-zero val_set_size should not be used with test_datasets configuration") return data @model_validator(mode="before") @classmethod def check_eval_strategy(cls, data): - if ( - data.get("evaluation_strategy") is not None - and data.get("eval_strategy") is None - ): - LOG.info( - "explicitly setting `eval_strategy` from the `evaluation_strategy`" - ) + if data.get("evaluation_strategy") is not None and data.get("eval_strategy") is None: + LOG.info("explicitly setting `eval_strategy` from the `evaluation_strategy`") data["eval_strategy"] = data.get("evaluation_strategy") return data @@ -1537,9 +1376,7 @@ class AxolotlInputConfig( and data.get("fsdp_config") and data["fsdp_config"].get("fsdp_offload_params") ): - raise ValueError( - f"FSDP Offload not compatible with {data.get('optimizer')}" - ) + raise ValueError(f"FSDP Offload not compatible with {data.get('optimizer')}") return data @model_validator(mode="before") @@ -1551,27 +1388,21 @@ class AxolotlInputConfig( and data.get("fsdp_config") and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" ): - raise ValueError( - "FSDP SHARDED_STATE_DICT not compatible with save_safetensors" - ) + raise ValueError("FSDP SHARDED_STATE_DICT not compatible with save_safetensors") return data @model_validator(mode="before") @classmethod def check_causal_lm_evals(cls, data): if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"): - raise ValueError( - "do_causal_lm_eval is enabled, eval_sample_packing must be set to False" - ) + raise ValueError("do_causal_lm_eval is enabled, eval_sample_packing must be set to False") if data.get("eval_causal_lm_metrics"): if not isinstance(data.get("eval_causal_lm_metrics"), list): raise ValueError("eval_causal_lm_metrics must be a list") # only ["sacrebleu", "comet", "ter", "chrf"] supported if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS: - raise ValueError( - f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}" - ) + raise ValueError(f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}") return data @model_validator(mode="before") @@ -1584,22 +1415,14 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod def check_xentropy_patch_conflicts(cls, data): - if data.get("flash_attn_cross_entropy") and data.get( - "unsloth_cross_entropy_loss" - ): - raise ValueError( - "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled" - ) + if data.get("flash_attn_cross_entropy") and data.get("unsloth_cross_entropy_loss"): + raise ValueError("flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled") return data @model_validator(mode="before") @classmethod def check_qlora_unsloth(cls, data): - if ( - data.get("unsloth_lora_mlp") - or data.get("unsloth_lora_qkv") - or data.get("unsloth_lora_o") - ): + if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"): if data.get("adapter") == "lora" and data.get("load_in_8bit"): raise ValueError( "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA" @@ -1609,11 +1432,7 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod def check_lora_8bit(cls, data): - if ( - data.get("lora_mlp_kernel") - or data.get("lora_qkv_kernel") - or data.get("lora_o_kernel") - ): + if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"): if data.get("adapter") == "lora" and data.get("load_in_8bit"): raise ValueError( "lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA" @@ -1623,13 +1442,8 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod def check_lora_axolotl_unsloth(cls, data): - is_lora_kernel = any( - data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"] - ) - is_unsloth_lora = any( - data.get(k) - for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"] - ) + is_lora_kernel = any(data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]) + is_unsloth_lora = any(data.get(k) for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]) if is_lora_kernel and is_unsloth_lora: raise ValueError( "both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)" @@ -1640,9 +1454,7 @@ class AxolotlInputConfig( @classmethod def check_torch_compile_deepspeed(cls, data): if data.get("deepspeed") and data.get("torch_compile"): - raise ValueError( - "torch_compile should be set within your deepspeed config file" - ) + raise ValueError("torch_compile should be set within your deepspeed config file") return data @model_validator(mode="before") @@ -1692,6 +1504,7 @@ class AxolotlInputConfig( and "use_reentrant" in data.get("gradient_checkpointing_kwargs") and data.get("gradient_checkpointing_kwargs").get("use_reentrant") and data.get("load_in_4bit") + and data.get("adapter") == "qlora" and data.get("capabilities") and data.get("capabilities").get("n_gpu", 1) > 1 ): @@ -1725,15 +1538,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): def check_bf16(self): if self.capabilities.bf16: if not self.bf16 and not self.bfloat16: - LOG.info( - "bf16 support detected, but not enabled for this configuration." - ) + LOG.info("bf16 support detected, but not enabled for this configuration.") else: - if ( - not self.merge_lora - and not self.is_preprocess - and (self.bf16 is True or self.bfloat16 is True) - ): + if not self.merge_lora and not self.is_preprocess and (self.bf16 is True or self.bfloat16 is True): raise ValueError( "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." ) @@ -1742,10 +1549,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): @model_validator(mode="before") @classmethod def check_sample_packing_w_sdpa_bf16(cls, data): - is_sm_90: bool = ( - data["capabilities"] - and data["capabilities"].get("compute_capability") == "sm_90" - ) + is_sm_90: bool = data["capabilities"] and data["capabilities"].get("compute_capability") == "sm_90" if ( data.get("sample_packing") and data.get("sdp_attention") @@ -1770,11 +1574,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): @model_validator(mode="before") @classmethod def check_multigpu_unsloth(cls, data): - if ( - data.get("unsloth_lora_mlp") - or data.get("unsloth_lora_qkv") - or data.get("unsloth_lora_o") - ): + if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"): capabilities = data.get("capabilities") if capabilities and capabilities.get("n_gpu", 0) > 1: raise ValueError( @@ -1785,11 +1585,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): @model_validator(mode="before") @classmethod def check_multigpu_lora_kernels(cls, data): - if ( - data.get("lora_mlp_kernel") - or data.get("lora_qkv_kernel") - or data.get("lora_o_kernel") - ): + if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"): capabilities = data.get("capabilities") is_fsdp = data.get("fsdp") is not None is_deepspeed = data.get("deepspeed") is not None @@ -1818,9 +1614,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): torch_version = str(torch.__version__).split("+", maxsplit=1)[0] if version.parse(torch_version) < version.parse("2.5.1"): - raise ValueError( - "ADOPT optimizer is incompatible with torch version < 2.5.1" - ) + raise ValueError("ADOPT optimizer is incompatible with torch version < 2.5.1") return data @model_validator(mode="before") @@ -1829,12 +1623,8 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): if data.get("torch_compile") == "auto": env_capabilities = data.get("env_capabilities", {}) if env_capabilities.get("torch_version"): - if version.parse( - env_capabilities.get("torch_version") - ) >= version.parse("2.5.1"): - LOG.info( - "torch.compile is available, setting torch_compile to True" - ) + if version.parse(env_capabilities.get("torch_version")) >= version.parse("2.5.1"): + LOG.info("torch.compile is available, setting torch_compile to True") data["torch_compile"] = True else: data["torch_compile"] = False @@ -1899,16 +1689,13 @@ def handle_legacy_message_fields_logic(data: dict) -> dict: ) if ( "content" in data["message_property_mappings"] - and data["message_property_mappings"]["content"] - != data["message_field_content"] + and data["message_property_mappings"]["content"] != data["message_field_content"] ): raise ValueError( f"Conflicting message content fields: message_field_content='{data['message_field_content']}' " f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'" ) - data["message_property_mappings"]["content"] = ( - data["message_field_content"] or "content" - ) + data["message_property_mappings"]["content"] = data["message_field_content"] or "content" del data["message_field_content"] elif "content" not in data["message_property_mappings"]: