From 690908cf2fe57a496ebad719e86457021d02f034 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Tue, 18 Mar 2025 11:23:23 +0000
Subject: [PATCH] linting

---
 .../config/models/input/v0_4_1/__init__.py | 404 ++++++++++++++----
 1 file changed, 309 insertions(+), 95 deletions(-)

diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index cca8b92a1..897fe760e 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -126,8 +126,12 @@ class DeprecatedParameters(BaseModel):
 class RemappedParameters(BaseModel):
     """parameters that have been remapped to other names"""

-    overrides_of_model_config: Optional[Dict[str, Any]] = Field(default=None, alias="model_config")
-    overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(default=None, alias="model_kwargs")
+    overrides_of_model_config: Optional[Dict[str, Any]] = Field(
+        default=None, alias="model_config"
+    )
+    overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None, alias="model_kwargs"
+    )
     type_of_model: Optional[str] = Field(default=None, alias="model_type")
     revision_of_model: Optional[str] = Field(default=None, alias="model_revision")

@@ -196,8 +200,12 @@ class SFTDataset(BaseModel):
     field_human: Optional[str] = None
     field_model: Optional[str] = None
     field_messages: Optional[str] = None
-    message_field_role: Optional[str] = None  # deprecated, use message_property_mappings
-    message_field_content: Optional[str] = None  # deprecated, use message_property_mappings
+    message_field_role: Optional[
+        str
+    ] = None  # deprecated, use message_property_mappings
+    message_field_content: Optional[
+        str
+    ] = None  # deprecated, use message_property_mappings
     message_property_mappings: Optional[Dict[str, str]] = None
     message_field_training: Optional[str] = None
     message_field_training_detail: Optional[str] = None
@@ -227,8 +235,12 @@ class SFTDataset(BaseModel):
             data["chat_template"] = ChatTemplate.tokenizer_default

         # if chat_template is set to jinja, chat_template_jinja is required
-        if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"):
-            raise ValueError("chat_template_jinja is required when chat_template is set to jinja")
+        if data.get("chat_template") == ChatTemplate.jinja and not data.get(
+            "chat_template_jinja"
+        ):
+            raise ValueError(
+                "chat_template_jinja is required when chat_template is set to jinja"
+            )

         # If chat_template_jinja is set, set chat_template to jinja
         if data.get("chat_template_jinja") and not data.get("chat_template"):
@@ -300,7 +312,9 @@ DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedData
 class LoftQConfig(BaseModel):
     """LoftQ configuration subset"""

-    loftq_bits: int = Field(default=4, json_schema_extra={"description": "Quantization bits for LoftQ"})
+    loftq_bits: int = Field(
+        default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
+    )
     # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})


@@ -345,7 +359,9 @@ class LoraConfig(BaseModel):

     qlora_sharded_model_loading: Optional[bool] = Field(
         default=False,
-        json_schema_extra={"description": "load qlora model in sharded format for FSDP using answer.ai technique."},
+        json_schema_extra={
+            "description": "load qlora model in sharded format for FSDP using answer.ai technique."
+        },
     )
     lora_on_cpu: Optional[bool] = None
     gptq: Optional[bool] = None
@@ -353,11 +369,15 @@ class LoraConfig(BaseModel):

     loraplus_lr_ratio: Optional[float] = Field(
         default=None,
-        json_schema_extra={"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."},
+        json_schema_extra={
+            "description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
+        },
     )
     loraplus_lr_embedding: Optional[float] = Field(
         default=1e-6,
-        json_schema_extra={"description": "loraplus learning rate for lora embedding layers."},
+        json_schema_extra={
+            "description": "loraplus learning rate for lora embedding layers."
+        },
     )
     merge_lora: Optional[bool] = None

@@ -465,11 +485,15 @@ class HyperparametersConfig(BaseModel):
     )
     batch_size: Optional[int] = Field(
         default=None,
-        json_schema_extra={"description": "Total batch size, we do not recommended setting this manually"},
+        json_schema_extra={
+            "description": "Total batch size, we do not recommended setting this manually"
+        },
     )
     eval_batch_size: Optional[int] = Field(
         default=None,
-        json_schema_extra={"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"},
+        json_schema_extra={
+            "description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
+        },
     )
     auto_find_batch_size: Optional[bool] = None

@@ -481,7 +505,9 @@ class HyperparametersConfig(BaseModel):
     embedding_lr: Optional[float] = None
     embedding_lr_scale: Optional[float] = None
     weight_decay: Optional[float] = 0.0
-    optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = OptimizerNames.ADAMW_HF
+    optimizer: Optional[
+        Union[OptimizerNames, CustomSupportedOptimizers]
+    ] = OptimizerNames.ADAMW_HF
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None,
         json_schema_extra={"description": "Optional arguments to supply to optimizer."},
@@ -493,7 +519,9 @@ class HyperparametersConfig(BaseModel):
         },
     )
     torchdistx_path: Optional[str] = None
-    lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]] = SchedulerType.COSINE
+    lr_scheduler: Optional[
+        Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
+    ] = SchedulerType.COSINE
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
     lr_quadratic_warmup: Optional[bool] = None
     cosine_min_lr_ratio: Optional[float] = None
@@ -617,15 +645,21 @@ class RayConfig(BaseModel):
     use_ray: bool = Field(default=False)
     ray_run_name: Optional[str] = Field(
         default=None,
-        json_schema_extra={"help": "The training results will be saved at `saves/ray_run_name`."},
+        json_schema_extra={
+            "help": "The training results will be saved at `saves/ray_run_name`."
+        },
     )
     ray_num_workers: int = Field(
         default=1,
-        json_schema_extra={"help": "The number of workers for Ray training. Default is 1 worker."},
+        json_schema_extra={
+            "help": "The number of workers for Ray training. Default is 1 worker."
+        },
     )
     resources_per_worker: dict = Field(
         default_factory=lambda: {"GPU": 1},
-        json_schema_extra={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
+        json_schema_extra={
+            "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
+        },
     )


@@ -665,9 +699,9 @@ class AxolotlInputConfig(
     reward_model: Optional[bool] = None
     process_reward_model: Optional[bool] = None
     num_labels: Optional[int] = None
-    dpo_use_weighting: Optional[bool] = (
-        None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
-    )
+    dpo_use_weighting: Optional[
+        bool
+    ] = None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
     dpo_use_logits_to_keep: Optional[bool] = None

     datasets: Optional[
@@ -689,7 +723,9 @@ class AxolotlInputConfig(
     dataset_shard_idx: Optional[int] = None
     skip_prepare_dataset: Optional[bool] = False

-    pretraining_dataset: Optional[Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]] = Field(
+    pretraining_dataset: Optional[
+        Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]
+    ] = Field(
         default=None,
         json_schema_extra={"description": "streaming dataset to use for pretraining"},
     )
@@ -744,7 +780,9 @@ class AxolotlInputConfig(

     # torch_dtype: Optional[torch.dtype]

-    gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = Field(default=False)
+    gradient_checkpointing: Optional[
+        Union[Literal["unsloth", "offload"], bool]
+    ] = Field(default=False)
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None

     unfrozen_parameters: Optional[List[str]] = None
@@ -811,7 +849,9 @@ class AxolotlInputConfig(
     deepspeed: Optional[Union[str, Dict[str, Any]]] = None
     fsdp: Optional[List[str]] = None
     fsdp_config: Optional[Dict[str, Any]] = None
-    fsdp_final_state_dict_type: Optional[Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]] = None
+    fsdp_final_state_dict_type: Optional[
+        Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
+    ] = None

     val_set_size: Optional[float] = Field(default=0.0)

@@ -821,7 +861,9 @@ class AxolotlInputConfig(

     torch_compile: Optional[Union[Literal["auto"], bool]] = None
     torch_compile_backend: Optional[str] = None
-    torch_compile_mode: Optional[Literal["default", "reduce-overhead", "max-autotune"]] = None
+    torch_compile_mode: Optional[
+        Literal["default", "reduce-overhead", "max-autotune"]
+    ] = None

     max_steps: Optional[int] = None
     warmup_steps: Optional[int] = None
@@ -852,7 +894,9 @@ class AxolotlInputConfig(
     kto_undesirable_weight: Optional[float] = None
     rl_beta: Optional[float] = None

-    max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = None
+    max_memory: Optional[
+        Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
+    ] = None
     gpu_memory_limit: Optional[Union[int, str]] = None
     low_cpu_mem_usage: Optional[bool] = None

@@ -888,7 +932,11 @@ class AxolotlInputConfig(
     def deprecate_sharegpt_datasets(cls, datasets):
         for _, ds_cfg in enumerate(datasets):
             # Handle both dict and pydantic model cases
-            ds_type = ds_cfg.get("type") if isinstance(ds_cfg, dict) else getattr(ds_cfg, "type", None)
+            ds_type = (
+                ds_cfg.get("type")
+                if isinstance(ds_cfg, dict)
+                else getattr(ds_cfg, "type", None)
+            )

             if not ds_type:
                 continue
@@ -897,12 +945,16 @@ class AxolotlInputConfig(
                 continue

             if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
-                raise ValueError("`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead.")
+                raise ValueError(
+                    "`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
+                )

         return datasets

     @field_serializer("datasets")
-    def datasets_serializer(self, ds_configs: Optional[List[DatasetConfig]]) -> Optional[List[Dict[str, Any]]]:
+    def datasets_serializer(
+        self, ds_configs: Optional[List[DatasetConfig]]
+    ) -> Optional[List[Dict[str, Any]]]:
         if ds_configs:
             return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
         return None
@@ -968,7 +1020,9 @@ class AxolotlInputConfig(
     @classmethod
     def check_sample_packing_w_xformers(cls, data):
         if data.get("sample_packing") and data.get("xformers_attention"):
-            raise ValueError("sample_packing not compatible with xformers_attention. Use flash_attention")
+            raise ValueError(
+                "sample_packing not compatible with xformers_attention. Use flash_attention"
+            )

         return data

@@ -976,8 +1030,12 @@ class AxolotlInputConfig(
     @classmethod
     def check_chat_template_config(cls, data):
         # if chat_template is set to jinja, chat_template_jinja is required
-        if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"):
-            raise ValueError("chat_template_jinja is required when chat_template is set to jinja")
+        if data.get("chat_template") == ChatTemplate.jinja and not data.get(
+            "chat_template_jinja"
+        ):
+            raise ValueError(
+                "chat_template_jinja is required when chat_template is set to jinja"
+            )

         # If chat_template_jinja is set, set chat_template to jinja
         if data.get("chat_template_jinja") and not data.get("chat_template"):
@@ -988,8 +1046,14 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_sample_packing_wo_flash(cls, data):
-        if data.get("sample_packing") and not data.get("flash_attention") and not data.get("sdp_attention"):
-            LOG.warning("sample_packing without flash_attention or sdp_attention does not handle cross-attention.")
+        if (
+            data.get("sample_packing")
+            and not data.get("flash_attention")
+            and not data.get("sdp_attention")
+        ):
+            LOG.warning(
+                "sample_packing without flash_attention or sdp_attention does not handle cross-attention."
+            )

         return data

@@ -1028,14 +1092,18 @@ class AxolotlInputConfig(
     @classmethod
     def hint_sample_packing_padding(cls, data):
         if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
-            LOG.warning("`pad_to_sequence_len: true` is recommended when using sample_packing")
+            LOG.warning(
+                "`pad_to_sequence_len: true` is recommended when using sample_packing"
+            )
         return data

     @model_validator(mode="before")
     @classmethod
     def hint_reward_model_pad(cls, data):
         if data.get("reward_model") and not data.get("pad_to_sequence_len"):
-            LOG.warning("`pad_to_sequence_len: true` is recommended when using reward_model")
+            LOG.warning(
+                "`pad_to_sequence_len: true` is recommended when using reward_model"
+            )
         if data.get("pad_to_sequence_len") is None:
             data["pad_to_sequence_len"] = True
         return data
@@ -1044,7 +1112,9 @@ class AxolotlInputConfig(
     @classmethod
     def check_gas_bsz(cls, data):
         if data.get("gradient_accumulation_steps") and data.get("batch_size"):
-            raise ValueError("please set only one of gradient_accumulation_steps or batch_size")
+            raise ValueError(
+                "please set only one of gradient_accumulation_steps or batch_size"
+            )
         return data

     @model_validator(mode="before")
@@ -1055,14 +1125,21 @@ class AxolotlInputConfig(
             and data.get("micro_batch_size")
             and data.get("eval_batch_size") != data.get("micro_batch_size")
         ):
-            LOG.warning("eval_batch_size != micro_batch_size. This can lead to VRAM instability.")
+            LOG.warning(
+                "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
+            )
         return data

     @model_validator(mode="before")
     @classmethod
     def check_push_ds_auth(cls, data):
-        if data.get("push_dataset_to_hub") and data.get("hf_use_auth_token") is not True:
-            raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub")
+        if (
+            data.get("push_dataset_to_hub")
+            and data.get("hf_use_auth_token") is not True
+        ):
+            raise ValueError(
+                "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
+            )
         return data

     @model_validator(mode="after")
@@ -1073,14 +1150,18 @@ class AxolotlInputConfig(

     @model_validator(mode="after")
     def check_mpt_checkpointing(self):
-        if (self.base_model and "mpt" in self.base_model.lower()) and self.gradient_checkpointing:
+        if (
+            self.base_model and "mpt" in self.base_model.lower()
+        ) and self.gradient_checkpointing:
             raise ValueError("gradient_checkpointing is not supported for MPT models")
         return self

     @model_validator(mode="after")
     def check_offload_grad_checkpointing(self):
         if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
-            LOG.warning("`unsloth` is deprecated for gradient_checkpointing, use `offload`")
+            LOG.warning(
+                "`unsloth` is deprecated for gradient_checkpointing, use `offload`"
+            )
             self.gradient_checkpointing = "offload"
         return self

@@ -1088,7 +1169,9 @@ class AxolotlInputConfig(
     def check_better_transformers(self):
         if self.flash_optimum is True:
             if self.adapter:
-                LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
+                LOG.warning(
+                    "BetterTransformers probably doesn't work with PEFT adapters"
+                )
             if self.fp16 or self.bf16:
                 raise ValueError("AMP is not supported with BetterTransformer")
             if self.float16 is not True and self.bfloat16 is not True:
@@ -1116,25 +1199,39 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_saves(cls, data):
-        if data.get("save_strategy") and data.get("save_steps") and data.get("save_strategy") != "steps":
+        if (
+            data.get("save_strategy")
+            and data.get("save_steps")
+            and data.get("save_strategy") != "steps"
+        ):
             raise ValueError(
                 "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
             )
         if data.get("saves_per_epoch") and data.get("save_steps"):
-            raise ValueError("save_steps and saves_per_epoch are mutually exclusive and cannot be used together.")
+            raise ValueError(
+                "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
+            )
         return data

     @model_validator(mode="before")
     @classmethod
     def check_push_save(cls, data):
-        if data.get("hub_model_id") and (data.get("save_strategy") not in ["steps", "epoch", None]):
-            LOG.warning("hub_model_id is set without any models being saved. To save a model, set save_strategy.")
+        if data.get("hub_model_id") and (
+            data.get("save_strategy") not in ["steps", "epoch", None]
+        ):
+            LOG.warning(
+                "hub_model_id is set without any models being saved. To save a model, set save_strategy."
+            )
         return data

     @model_validator(mode="before")
     @classmethod
     def check_evals(cls, data):
-        if data.get("eval_strategy") and data.get("eval_steps") and data.get("eval_strategy") != "steps":
+        if (
+            data.get("eval_strategy")
+            and data.get("eval_steps")
+            and data.get("eval_strategy") != "steps"
+        ):
             raise ValueError(
                 "eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
             )
@@ -1144,21 +1241,41 @@ class AxolotlInputConfig(
             and (data.get("eval_steps") or data.get("eval_strategy"))
             and not data.get("test_datasets")
         ):
-            raise ValueError("eval_steps and eval_strategy are not supported with val_set_size == 0")
+            raise ValueError(
+                "eval_steps and eval_strategy are not supported with val_set_size == 0"
+            )
         if data.get("evals_per_epoch") and data.get("eval_steps"):
-            raise ValueError("eval_steps and evals_per_epoch are mutually exclusive and cannot be used together.")
-        if data.get("evals_per_epoch") and data.get("eval_strategy") and data.get("eval_strategy") != "steps":
-            raise ValueError("eval_strategy must be empty or set to `steps` when used with evals_per_epoch.")
+            raise ValueError(
+                "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
+            )
+        if (
+            data.get("evals_per_epoch")
+            and data.get("eval_strategy")
+            and data.get("eval_strategy") != "steps"
+        ):
+            raise ValueError(
+                "eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
+            )

-        if data.get("do_bench_eval") and not (data.get("evals_per_epoch") or data.get("eval_steps")):
-            raise ValueError("do_bench_eval requires evals_per_epoch or eval_steps to be set.")
+        if data.get("do_bench_eval") and not (
+            data.get("evals_per_epoch") or data.get("eval_steps")
+        ):
+            raise ValueError(
+                "do_bench_eval requires evals_per_epoch or eval_steps to be set."
+            )
         return data

     @model_validator(mode="before")
     @classmethod
     def check_test_datasets_bench(cls, data):
-        if data.get("do_bench_eval") and not data.get("test_datasets") and not data.get("val_set_size"):
-            LOG.warning("`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset.")
+        if (
+            data.get("do_bench_eval")
+            and not data.get("test_datasets")
+            and not data.get("val_set_size")
+        ):
+            LOG.warning(
+                "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
+            )
             data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
         return data

@@ -1167,12 +1284,22 @@ class AxolotlInputConfig(
     def check_eval_packing(cls, data):
         # TODO also should check test_datasets and val_set_size as we can skip
         # if there are no eval datasets/splits
-        if data.get("sample_packing") and data.get("eval_table_size") and data.get("eval_sample_packing") is not False:
+        if (
+            data.get("sample_packing")
+            and data.get("eval_table_size")
+            and data.get("eval_sample_packing") is not False
+        ):
             raise ValueError(
                 "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
             )
-        if data.get("sample_packing") and data.get("eval_sample_packing") is None and not data.get("eval_table_size"):
-            LOG.info("explicitly setting `eval_sample_packing` to match `sample_packing`")
+        if (
+            data.get("sample_packing")
+            and data.get("eval_sample_packing") is None
+            and not data.get("eval_table_size")
+        ):
+            LOG.info(
+                "explicitly setting `eval_sample_packing` to match `sample_packing`"
+            )
             data["eval_sample_packing"] = True

         if (
@@ -1192,7 +1319,9 @@ class AxolotlInputConfig(
     def check_mm_prepare(cls, data):
         if data.get("skip_prepare_dataset"):
             if data.get("remove_unused_columns") is None:
-                LOG.info("setting `remove_unused_columns: false` for skip_prepare_dataset")
+                LOG.info(
+                    "setting `remove_unused_columns: false` for skip_prepare_dataset"
+                )
                 data["remove_unused_columns"] = False

         return data
@@ -1233,13 +1362,19 @@ class AxolotlInputConfig(
     @model_validator(mode="after")
     def check_simpo_warmup(self):
         if self.rl == "simpo" and self.warmup_ratio:
-            raise ValueError("warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead")
+            raise ValueError(
+                "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
+            )
         return self

     @model_validator(mode="before")
     @classmethod
     def check_frozen(cls, data):
-        if data.get("adapter") and data.get("peft_layers_to_transform") and data.get("unfrozen_parameters"):
+        if (
+            data.get("adapter")
+            and data.get("peft_layers_to_transform")
+            and data.get("unfrozen_parameters")
+        ):
             raise ValueError(
                 "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
             )
@@ -1250,7 +1385,9 @@ class AxolotlInputConfig(
     @classmethod
     def check_peft_layers_pattern(cls, data):
         if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"):
-            raise ValueError("peft_layers_pattern requires peft_layers_to_transform to be set")
+            raise ValueError(
+                "peft_layers_pattern requires peft_layers_to_transform to be set"
+            )
         return data

     @model_validator(mode="after")
@@ -1273,13 +1410,17 @@ class AxolotlInputConfig(

     @model_validator(mode="after")
     def check_fused_lora(self):
-        if self.adapter in ["lora", "qlora"] and (self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp):
+        if self.adapter in ["lora", "qlora"] and (
+            self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
+        ):
             raise ValueError("Fused modules are not supported with LoRA/QLoRA")
         return self

     @model_validator(mode="after")
     def hint_lora_8bit(self):
-        loftq = self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
+        loftq = (
+            self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
+        )
         if not self.load_in_8bit and self.adapter == "lora" and not loftq:
             LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
         return self
@@ -1292,7 +1433,9 @@ class AxolotlInputConfig(
                     "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
                 )
             if self.save_steps % self.eval_steps != 0:
-                raise ValueError("`early_stopping_patience` requires that eval_steps should evenly divide save_steps.")
+                raise ValueError(
+                    "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
+                )
         return self

     @model_validator(mode="after")
@@ -1308,7 +1451,9 @@ class AxolotlInputConfig(
                 raise ValueError("deepspeed not supported with ReLoRA")

             if self.lr_scheduler == "one_cycle":
-                raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
+                raise ValueError(
+                    "ReLoRA is not compatible with the one_cycle scheduler"
+                )

             if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
                 raise ValueError("Fused modules are not supported with ReLoRA")
@@ -1317,8 +1462,13 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_mem_mismatch(cls, data):
-        if data.get("max_memory") is not None and data.get("gpu_memory_limit") is not None:
-            raise ValueError("max_memory and gpu_memory_limit are mutually exclusive and cannot be used together.")
+        if (
+            data.get("max_memory") is not None
+            and data.get("gpu_memory_limit") is not None
+        ):
+            raise ValueError(
+                "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
+            )
         return data

     @model_validator(mode="before")
@@ -1327,10 +1477,13 @@ class AxolotlInputConfig(
         if (
             data.get("unfrozen_parameters")
             and data.get("gradient_checkpointing_kwargs")
-            and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is True
+            and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
+            is True
         ):
             # https://github.com/huggingface/transformers/issues/21381
-            raise ValueError("`use_reentrant` must be false when used with partially frozen model.")
+            raise ValueError(
+                "`use_reentrant` must be false when used with partially frozen model."
+            )
         return data

     @model_validator(mode="before")
@@ -1339,7 +1492,8 @@ class AxolotlInputConfig(
         if (
             data.get("adapter") == "qlora"
             and data.get("gradient_checkpointing_kwargs", {})
-            and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is False
+            and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
+            is False
             and data.get("deepspeed", "") is not None
             and "zero3" in data.get("deepspeed", "")
         ):
@@ -1356,14 +1510,21 @@ class AxolotlInputConfig(
     @classmethod
     def check_val_w_test_datasets(cls, data):
         if data.get("test_datasets") and data.get("val_set_size"):
-            raise ValueError("non-zero val_set_size should not be used with test_datasets configuration")
+            raise ValueError(
+                "non-zero val_set_size should not be used with test_datasets configuration"
+            )
         return data

     @model_validator(mode="before")
     @classmethod
     def check_eval_strategy(cls, data):
-        if data.get("evaluation_strategy") is not None and data.get("eval_strategy") is None:
-            LOG.info("explicitly setting `eval_strategy` from the `evaluation_strategy`")
+        if (
+            data.get("evaluation_strategy") is not None
+            and data.get("eval_strategy") is None
+        ):
+            LOG.info(
+                "explicitly setting `eval_strategy` from the `evaluation_strategy`"
+            )
             data["eval_strategy"] = data.get("evaluation_strategy")
         return data

@@ -1376,7 +1537,9 @@ class AxolotlInputConfig(
             and data.get("fsdp_config")
             and data["fsdp_config"].get("fsdp_offload_params")
         ):
-            raise ValueError(f"FSDP Offload not compatible with {data.get('optimizer')}")
+            raise ValueError(
+                f"FSDP Offload not compatible with {data.get('optimizer')}"
+            )
         return data

     @model_validator(mode="before")
@@ -1388,21 +1551,27 @@ class AxolotlInputConfig(
             and data.get("fsdp_config")
             and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
         ):
-            raise ValueError("FSDP SHARDED_STATE_DICT not compatible with save_safetensors")
+            raise ValueError(
+                "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
+            )
         return data

     @model_validator(mode="before")
     @classmethod
     def check_causal_lm_evals(cls, data):
         if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
-            raise ValueError("do_causal_lm_eval is enabled, eval_sample_packing must be set to False")
+            raise ValueError(
+                "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
+            )

         if data.get("eval_causal_lm_metrics"):
             if not isinstance(data.get("eval_causal_lm_metrics"), list):
                 raise ValueError("eval_causal_lm_metrics must be a list")
             # only ["sacrebleu", "comet", "ter", "chrf"] supported
             if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
-                raise ValueError(f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}")
+                raise ValueError(
+                    f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
+                )
         return data

     @model_validator(mode="before")
@@ -1415,14 +1584,22 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_xentropy_patch_conflicts(cls, data):
-        if data.get("flash_attn_cross_entropy") and data.get("unsloth_cross_entropy_loss"):
-            raise ValueError("flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled")
+        if data.get("flash_attn_cross_entropy") and data.get(
+            "unsloth_cross_entropy_loss"
+        ):
+            raise ValueError(
+                "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
+            )
         return data

     @model_validator(mode="before")
     @classmethod
     def check_qlora_unsloth(cls, data):
-        if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"):
+        if (
+            data.get("unsloth_lora_mlp")
+            or data.get("unsloth_lora_qkv")
+            or data.get("unsloth_lora_o")
+        ):
             if data.get("adapter") == "lora" and data.get("load_in_8bit"):
                 raise ValueError(
                     "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
@@ -1432,7 +1609,11 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_lora_8bit(cls, data):
-        if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"):
+        if (
+            data.get("lora_mlp_kernel")
+            or data.get("lora_qkv_kernel")
+            or data.get("lora_o_kernel")
+        ):
             if data.get("adapter") == "lora" and data.get("load_in_8bit"):
                 raise ValueError(
                     "lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
@@ -1442,8 +1623,13 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_lora_axolotl_unsloth(cls, data):
-        is_lora_kernel = any(data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"])
-        is_unsloth_lora = any(data.get(k) for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"])
+        is_lora_kernel = any(
+            data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
+        )
+        is_unsloth_lora = any(
+            data.get(k)
+            for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
+        )
         if is_lora_kernel and is_unsloth_lora:
             raise ValueError(
                 "both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
@@ -1454,7 +1640,9 @@ class AxolotlInputConfig(
     @classmethod
     def check_torch_compile_deepspeed(cls, data):
         if data.get("deepspeed") and data.get("torch_compile"):
-            raise ValueError("torch_compile should be set within your deepspeed config file")
+            raise ValueError(
+                "torch_compile should be set within your deepspeed config file"
+            )
         return data

     @model_validator(mode="before")
@@ -1538,9 +1726,15 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     def check_bf16(self):
         if self.capabilities.bf16:
             if not self.bf16 and not self.bfloat16:
-                LOG.info("bf16 support detected, but not enabled for this configuration.")
+                LOG.info(
+                    "bf16 support detected, but not enabled for this configuration."
+                )
             else:
-                if not self.merge_lora and not self.is_preprocess and (self.bf16 is True or self.bfloat16 is True):
+                if (
+                    not self.merge_lora
+                    and not self.is_preprocess
+                    and (self.bf16 is True or self.bfloat16 is True)
+                ):
                     raise ValueError(
                         "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
                     )
@@ -1549,7 +1743,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     @model_validator(mode="before")
     @classmethod
     def check_sample_packing_w_sdpa_bf16(cls, data):
-        is_sm_90: bool = data["capabilities"] and data["capabilities"].get("compute_capability") == "sm_90"
+        is_sm_90: bool = (
+            data["capabilities"]
+            and data["capabilities"].get("compute_capability") == "sm_90"
+        )
         if (
             data.get("sample_packing")
             and data.get("sdp_attention")
@@ -1574,7 +1771,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     @model_validator(mode="before")
     @classmethod
     def check_multigpu_unsloth(cls, data):
-        if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"):
+        if (
+            data.get("unsloth_lora_mlp")
+            or data.get("unsloth_lora_qkv")
+            or data.get("unsloth_lora_o")
+        ):
             capabilities = data.get("capabilities")
             if capabilities and capabilities.get("n_gpu", 0) > 1:
                 raise ValueError(
@@ -1585,7 +1786,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     @model_validator(mode="before")
     @classmethod
     def check_multigpu_lora_kernels(cls, data):
-        if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"):
+        if (
+            data.get("lora_mlp_kernel")
+            or data.get("lora_qkv_kernel")
+            or data.get("lora_o_kernel")
+        ):
             capabilities = data.get("capabilities")
             is_fsdp = data.get("fsdp") is not None
             is_deepspeed = data.get("deepspeed") is not None
@@ -1614,7 +1819,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):

                 torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
             if version.parse(torch_version) < version.parse("2.5.1"):
-                raise ValueError("ADOPT optimizer is incompatible with torch version < 2.5.1")
+                raise ValueError(
+                    "ADOPT optimizer is incompatible with torch version < 2.5.1"
+                )
         return data

     @model_validator(mode="before")
@@ -1623,8 +1830,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         if data.get("torch_compile") == "auto":
            env_capabilities = data.get("env_capabilities", {})
            if env_capabilities.get("torch_version"):
-                if version.parse(env_capabilities.get("torch_version")) >= version.parse("2.5.1"):
-                    LOG.info("torch.compile is available, setting torch_compile to True")
+                if version.parse(
+                    env_capabilities.get("torch_version")
+                ) >= version.parse("2.5.1"):
+                    LOG.info(
+                        "torch.compile is available, setting torch_compile to True"
+                    )
                    data["torch_compile"] = True
            else:
                data["torch_compile"] = False
@@ -1689,13 +1900,16 @@ def handle_legacy_message_fields_logic(data: dict) -> dict:
            )
        if (
            "content" in data["message_property_mappings"]
-            and data["message_property_mappings"]["content"] != data["message_field_content"]
+            and data["message_property_mappings"]["content"]
+            != data["message_field_content"]
        ):
            raise ValueError(
                f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
                f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
            )
-        data["message_property_mappings"]["content"] = data["message_field_content"] or "content"
+        data["message_property_mappings"]["content"] = (
+            data["message_field_content"] or "content"
+        )
         del data["message_field_content"]

     elif "content" not in data["message_property_mappings"]: