diff --git a/examples/olmo3/README.md b/examples/olmo3/README.md
index 2f98eb73e..160223628 100644
--- a/examples/olmo3/README.md
+++ b/examples/olmo3/README.md
@@ -16,7 +16,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 axolotl train examples/olmo3/olmo3-7b-qlora.yaml
 ```
 
-Let us know how it goes. Happy finetuning! 🚀
+This uses about 11.3 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
 
 ### TIPS
 
diff --git a/examples/olmo3/olmo3-7b-qlora.yaml b/examples/olmo3/olmo3-7b-qlora.yaml
index c8878d79f..de2bf1d3d 100644
--- a/examples/olmo3/olmo3-7b-qlora.yaml
+++ b/examples/olmo3/olmo3-7b-qlora.yaml
@@ -42,10 +42,10 @@ wandb_watch:
 wandb_name:
 wandb_log_model:
 
-gradient_accumulation_steps: 4
+gradient_accumulation_steps: 2
 micro_batch_size: 2
 num_epochs: 1
-optimizer: adamw_bnb_8bit
+optimizer: adamw_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
 
diff --git a/src/axolotl/core/attention/flex_block_mask.py b/src/axolotl/core/attention/flex_block_mask.py
deleted file mode 100644
index 37149983c..000000000
--- a/src/axolotl/core/attention/flex_block_mask.py
+++ /dev/null
@@ -1,158 +0,0 @@
-"""
-monkeypatch for flex + packing
-"""
-
-import sys
-from typing import Callable, Optional, Union
-
-import torch
-from torch.nn.attention.flex_attention import BlockMask
-from transformers import Cache, PretrainedConfig
-from transformers.masking_utils import (
-    ALL_MASK_ATTENTION_FUNCTIONS,
-    _preprocess_mask_arguments,
-    and_masks,
-    causal_mask_function,
-    or_masks,
-)
-from transformers.utils import is_torch_greater_or_equal
-
-_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
-
-
-def create_causal_mask(
-    config: PretrainedConfig,
-    input_embeds: torch.Tensor,
-    attention_mask: torch.Tensor,
-    cache_position: torch.Tensor,
-    past_key_values: Optional[Cache],
-    or_mask_function: Optional[Callable] = None,
-    and_mask_function: Optional[Callable] = None,
-) -> Optional[Union[torch.Tensor, BlockMask]]:
-    """
-    Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
-    has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
-    to what is needed in the `modeling_xxx.py` files).
-
-    Args:
-        config (`PretrainedConfig`):
-            The model config.
-        input_embeds (`torch.Tensor`):
-            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
-            batch size, query length and dtype.
-        attention_mask (`torch.Tensor`, optional):
-            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
-            It can also be an already prepared 4D mask, in which case it is returned as-is.
-        cache_position (`torch.Tensor`):
-            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
-        past_key_values (`Cache`, optional):
-            The past key values, if we use a cache.
-        or_mask_function (`Callable`, optional):
-            An optional mask function to combine with the causal mask function (by doing the union of both). This is
-            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
-        and_mask_function (`Callable`, optional):
-            An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
-            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
- """ - # If we have an HybridCache structure, here we want to create the mask for the full layers - if ( - past_key_values - and hasattr(past_key_values, "is_sliding") - and False in past_key_values.is_sliding - ): - layer_idx = past_key_values.is_sliding.index(False) - else: - layer_idx = 0 - - original_attention_mask = ( - None - if attention_mask is None - else attention_mask.clone().to(cache_position.device) - ) - early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( - config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx - ) - if early_exit: - return attention_mask - - batch_size, total_seq_len = cache_position.shape - key_length = total_seq_len - document_ids = torch.nn.functional.pad( - original_attention_mask, value=0, pad=(0, key_length) - ) - - batch_size, dtype = input_embeds.shape[0], input_embeds.dtype - if attention_mask is not None: - - def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx): - """ - Defines the logic of a block causal mask by combining both a standard causal mask - and a block diagonal document mask. - See :func:`~torchtune.modules.attention_utils.create_block_causal_mask` - for an illustration. - """ - causal_mask_ = q_idx >= kv_idx # not valid when decoding - document_mask = ( - document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx] - ) - final_mask = causal_mask_ & document_mask - return final_mask - - mask_factory_function = causal_doc_mask_mod - else: - mask_factory_function = causal_mask_function - mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] - - # Do not allow skip if we are compiling (this is to match BC) - allow_is_causal_skip = ( - not past_key_values.is_compileable if past_key_values is not None else True - ) - - # Allow slight deviations from causal mask - if or_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_6: - raise ValueError( - "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6" - ) - mask_factory_function = or_masks(mask_factory_function, or_mask_function) - allow_is_causal_skip = False - if and_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_6: - raise ValueError( - "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6" - ) - mask_factory_function = and_masks(mask_factory_function, and_mask_function) - allow_is_causal_skip = False - - # We now create the mask - causal_mask = mask_interface( - batch_size=batch_size, - cache_position=cache_position, - kv_length=kv_length, - kv_offset=kv_offset, - mask_function=mask_factory_function, - attention_mask=attention_mask, - allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa - dtype=dtype, # Additional kwarg for eager - config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface - ) - return causal_mask - - -def patch_create_causal_mask(model_type): - import transformers.masking_utils - - transformers.masking_utils.create_causal_mask = create_causal_mask - - if model_type: - try: - # Dynamically import the module and attention class - module_path = f"transformers.models.{model_type}.modeling_{model_type}" - module = __import__(module_path) - module.create_causal_mask = create_causal_mask - del sys.modules[module_path] - except (ImportError, AttributeError) as e: - raise ValueError( - f"Could not import attention class for model_type: {model_type}. 
" - f"Error: {str(e)}" - ) from e diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index c31982262..64f363bb1 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -199,12 +199,6 @@ class PatchManager: flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {} patch_flex_wrapper(**flex_attn_compile_kwargs) - if self.cfg.sample_packing: - from axolotl.core.attention.flex_block_mask import ( - patch_create_causal_mask, - ) - - patch_create_causal_mask(self.cfg.model_config_type) def _apply_model_specific_patches(self): """Apply patches specific to model architectures.""" diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 35810897a..7aaedef28 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -2,9 +2,17 @@ import functools import logging +import warnings from axolotl.utils.distributed import is_main_process +# Suppress noisy bitsandbytes warnings about dtype casting during quantization +warnings.filterwarnings( + "ignore", + message=".*MatMul8bitLt: inputs will be cast from.*", + category=UserWarning, +) + # Adapted from Accelerate # https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 36565fb03..cb834a3bf 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -801,6 +801,36 @@ class OptimizationValidationMixin: ) return data + @model_validator(mode="before") + @classmethod + def check_cross_entropy_conflicts(cls, data): + """Check for mutual exclusivity between cross entropy patch options. + + Only one of the following can be enabled at a time: + - cut_cross_entropy (CutCrossEntropyPlugin) + - chunked_cross_entropy + - liger_cross_entropy (LigerPlugin) + - liger_fused_linear_cross_entropy (LigerPlugin) + """ + ce_options = { + "cut_cross_entropy": data.get("cut_cross_entropy"), + "chunked_cross_entropy": data.get("chunked_cross_entropy"), + "liger_cross_entropy": data.get("liger_cross_entropy"), + "liger_fused_linear_cross_entropy": data.get( + "liger_fused_linear_cross_entropy" + ), + } + + enabled_options = [k for k, v in ce_options.items() if v] + + if len(enabled_options) > 1: + raise ValueError( + f"Only one cross entropy optimization can be enabled at a time. " + f"Found {len(enabled_options)} enabled: {', '.join(enabled_options)}. " + "Please disable all but one." + ) + return data + @model_validator(mode="before") @classmethod def check_fsdp_version(cls, data):