diff --git a/_quarto.yml b/_quarto.yml index c564fb0dd..0a8e023cf 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -133,6 +133,7 @@ quartodoc: - utils.schemas.datasets - utils.schemas.peft - utils.schemas.trl + - utils.schemas.multimodal - utils.schemas.integrations - utils.schemas.enums - utils.schemas.utils diff --git a/docs/config.qmd b/docs/config.qmd index f166f8050..2a79a0126 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -586,6 +586,14 @@ resume_from_checkpoint: # Be careful with this being turned on between different models. auto_resume_from_checkpoints: false +## Multimodal section +# int | tuple[int, int] | None . Size to resize images to, width x height. +# Will read from model/processor config if not set. +image_size: +# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear". +image_resize_algorithm: 'bilinear' +## End of multimodal section + # Don't mess with this, it's here for accelerate and torchrun local_rank: diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index 2381566ad..e8e793482 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -1,28 +1,171 @@ -# MultiModal / Vision Language Models (BETA) +--- +title: MultiModal / Vision Language Models (BETA) +format: + html: + toc: true + toc-depth: 3 +--- -### Supported Models +## Supported Models -- Mllama, i.e. llama with vision models +- [Mllama](#sec-mllama) +- [Pixtral](#sec-pixtral) +- [Llava-1.5](#sec-llava-15) +- [Mistral-Small-3.1](#sec-mistral-small-31) +- [Gemma-3](#sec-gemma-3) +- [Qwen2-VL](#sec-qwen2-vl) +- [Qwen2.5-VL](#sec-qwen25-vl) -### Usage +## Usage -Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA, -you'll need to use the following in YAML in combination with the rest of the required hyperparams. +Multimodal support is limited and doesn't have full feature parity. + +Here are the hyperparams you'll need to use to finetune a multimodal model. ```yaml -base_model: alpindale/Llama-3.2-11B-Vision-Instruct processor_type: AutoProcessor -skip_prepare_dataset: true -chat_template: llama3_2_vision +skip_prepare_dataset: true +remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training +sample_packing: false # not yet supported with multimodal + +chat_template: # see in next section + +# example dataset datasets: - path: HuggingFaceH4/llava-instruct-mix-vsft type: chat_template split: train[:1%] field_messages: messages -remove_unused_columns: false -sample_packing: false -# only finetune the Language model, leave the vision model and vision tower frozen +# (optional) if doing lora, only finetune the Language model, +# leave the vision model and vision tower frozen +# load_in_8bit: true +adapter: lora lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +# (optional) if you want to resize images to a set size +image_size: 512 +image_resize_algorithm: bilinear +``` + +Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs. + +::: {.callout-warning} +Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs. 
+::: + +### Mllama {#sec-mllama} + +```yaml +base_model: meta-llama/Llama-3.2-11B-Vision-Instruct + +chat_template: llama3_2_vision +``` + +### Pixtral {#sec-pixtral} + +```yaml +base_model: mistralai/Pixtral-12B-2409 + +chat_template: pixtral +``` + +### Llava-1.5 {#sec-llava-15} + +```yaml +base_model: llava-hf/llava-1.5-7b-hf + +chat_template: llava +``` + +### Mistral-Small-3.1 {#sec-mistral-small-31} + +```yaml +base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503 + +chat_template: mistral_v7_tekken +``` + +### Gemma-3 {#sec-gemma-3} + +::: {.callout-tip} +The Gemma3-1B model is a text-only model, so please train as regular text model. +::: + +For multi-modal 4B/12B/27B models, use the following config: + +```yaml +base_model: google/gemma-3-4b-it + +chat_template: gemma3 +``` + +### Qwen2-VL {#sec-qwen2-vl} + +```yaml +base_model: Qwen/Qwen2-VL-7B-Instruct + +chat_template: qwen2_vl +``` + +### Qwen2.5-VL {#sec-qwen25-vl} + +```yaml +base_model: Qwen/Qwen2.5-VL-7B-Instruct + +chat_template: qwen2_vl # same as qwen2-vl +``` + +## Dataset Format + +For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format. + +- A message is a list of `role` and `content`. +- `role` can be `system`, `user`, `assistant`, etc. +- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`). + +::: {.callout-note} +For backwards compatibility: + +- If the dataset has a `images` or `image` column of `list[Image]`, it will be appended to the first `content` list as `{"type": "image", "image": ...}`. However, if the content already has a `{"type": "image"}` but no `image` key, it will be set the `image` key. +- If `content` is a string, it will be converted to a list with `type` as `text`. +::: + +::: {.callout-tip} +For image loading, you can use the following keys within `content` alongside `"type": "image"`: + +- `"path": "/path/to/image.jpg"` +- `"url": "https://example.com/image.jpg"` +- `"base64": "..."` +- `"image": PIL.Image` +::: + +Here is an example of a multi-modal dataset: +```json +[ + { + "messages": [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."} + ] + }, + { + "role": "user", + "content": [ + {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"}, + {"type": "text", "text": "Describe this image in detail."} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "The image is a bee."} + ] + } + ] + } +] ``` diff --git a/examples/gemma3/gemma-3-4b-lora.yml b/examples/gemma3/gemma-3-4b-lora.yml new file mode 100644 index 000000000..b85392982 --- /dev/null +++ b/examples/gemma3/gemma-3-4b-lora.yml @@ -0,0 +1,63 @@ +base_model: google/gemma-3-4b-it +processor_type: AutoProcessor +strict: false + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: gemma3 +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 2048 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: 
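As a companion to the dataset format documented above, here is a minimal sketch of building such a row programmatically. The schema follows the `messages`/`content` structure from the Dataset Format section; using the `datasets` library (and this particular image URL) is an assumption for illustration, not something this PR requires.

```python
# Sketch: assemble a one-row multimodal dataset that follows the documented schema.
# Using the `datasets` library and this image URL is an assumption for the example.
from datasets import Dataset

image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"

row = {
    "messages": [
        {
            "role": "user",
            "content": [
                # any one of "image", "url", "path", or "base64" can carry the image
                {"type": "image", "url": image_url},
                {"type": "text", "text": "Describe this image in detail."},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "The image is a bee."}],
        },
    ]
}

ds = Dataset.from_list([row])
print(ds[0]["messages"][0]["content"][0]["type"])  # -> "image"
```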
+wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +local_rank: +logging_steps: 1 +flash_attention: true +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml new file mode 100644 index 000000000..9129d0122 --- /dev/null +++ b/examples/llava/lora-7b.yaml @@ -0,0 +1,63 @@ +base_model: llava-hf/llava-1.5-7b-hf +processor_type: AutoProcessor +strict: false + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: llava +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +local_rank: +logging_steps: 1 +flash_attention: true +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small-3.1-24B-lora.yml new file mode 100644 index 000000000..177484799 --- /dev/null +++ b/examples/mistral/mistral-small-3.1-24B-lora.yml @@ -0,0 +1,66 @@ +base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503 +processor_type: AutoProcessor +strict: false + +load_in_8bit: true + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: mistral_v7_tekken +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 2048 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +local_rank: +logging_steps: 1 +flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet. 
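A note on the `image_size` option these example configs may also set: as implemented in `processing_strategies.py` further down, an integer target letterboxes the image into a padded square (preserving aspect ratio), while a `(width, height)` tuple resizes to that exact shape. Below is a standalone sketch of the two behaviors using only PIL; the input dimensions are made up for illustration.

```python
# Sketch of the two image_size behaviors added in this PR:
#   int             -> aspect-preserving resize, padded out to a square canvas
#   (width, height) -> direct resize to the exact shape (may distort aspect ratio)
from PIL import Image, ImageOps
from PIL.Image import Resampling

img = Image.new("RGB", (1024, 640))  # made-up input size

# image_size: 512 -> letterbox into a 512x512 canvas, padding with black
square = ImageOps.pad(img, (512, 512), method=Resampling.BILINEAR, color=(0, 0, 0))

# image_size: [448, 336] -> resize to exactly 448x336
exact = img.resize((448, 336), Resampling.BILINEAR)

print(square.size, exact.size)  # (512, 512) (448, 336)
```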
+eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml new file mode 100644 index 000000000..7336a7ad0 --- /dev/null +++ b/examples/pixtral/lora-12b.yml @@ -0,0 +1,65 @@ +base_model: mistral-community/pixtral-12b +processor_type: AutoProcessor +strict: false + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: pixtral +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +local_rank: +logging_steps: 1 +flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + pad_token: diff --git a/examples/qwen2-vl/lora-7b.yaml b/examples/qwen2-vl/lora-7b.yaml new file mode 100644 index 000000000..e7ab13ddb --- /dev/null +++ b/examples/qwen2-vl/lora-7b.yaml @@ -0,0 +1,63 @@ +base_model: Qwen/Qwen2-VL-7B-Instruct +processor_type: AutoProcessor +strict: false + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: qwen2_vl +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +local_rank: +logging_steps: 1 +flash_attention: true +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b151be8fa..b237b1ef3 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -60,6 +60,7 @@ from axolotl.core.training_args import ( from axolotl.integrations.base import PluginManager from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from 
axolotl.monkeypatch.relora import ReLoRACallback +from axolotl.processing_strategies import get_processing_strategy from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( EvalFirstStepCallback, @@ -747,6 +748,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.accelerator_config ) + if self.cfg.image_size: + training_arguments_kwargs["image_size"] = self.cfg.image_size + if self.cfg.image_resize_algorithm: + training_arguments_kwargs["image_resize_algorithm"] = ( + self.cfg.image_resize_algorithm + ) if self.cfg.kd_ce_alpha is not None: training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha if self.cfg.kd_alpha is not None: @@ -890,8 +897,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): else: if self.cfg.processor_type and self.processor: collator = MultiModalChatDataCollator - kwargs["processor"] = self.processor - kwargs["chat_template"] = training_args.chat_template + kwargs["processing_strategy"] = get_processing_strategy( + self.processor, + training_args.chat_template, + self.cfg.chat_template, + image_size=training_args.image_size, + image_resize_algorithm=training_args.image_resize_algorithm, + ) elif self.cfg.batch_flattening: collator = DataCollatorWithFlattening collator_args.pop(0) diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 82a62c049..fbb363492 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -5,6 +5,7 @@ extra axolotl specific training args from dataclasses import dataclass, field from typing import Optional +from PIL.Image import Resampling from transformers import TrainingArguments from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig @@ -212,6 +213,20 @@ class AxolotlTrainingMixins: metadata={"help": "The number of workers to use in sequence parallelism"}, ) + # multi-modal section + + image_size: int | tuple[int, int] | None = field( + default=None, + metadata={"help": "The size of the image to resize to"}, + ) + + image_resize_algorithm: Resampling | None = field( + default=None, + metadata={"help": "The algorithm to use for image resizing"}, + ) + + # end of multi-modal section + @dataclass class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py new file mode 100644 index 000000000..0b854af8d --- /dev/null +++ b/src/axolotl/processing_strategies.py @@ -0,0 +1,278 @@ +"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types""" + +from copy import deepcopy +from typing import Optional + +from PIL import Image, ImageOps +from PIL.Image import Resampling +from torch import Tensor +from transformers import ProcessorMixin +from transformers.image_utils import load_image + + +class ProcessingStrategy: + """Base Processing Strategy class""" + + def __init__( + self, + processor: ProcessorMixin, + chat_template: Optional[str] = None, + image_size: int | tuple[int, int] | None = None, + image_resize_algorithm: Resampling | None = None, + ): + self.processor = processor + self.chat_template = chat_template + self.image_token = None + self.image_token_id = None + + self.image_size = image_size + self.image_resize_algorithm = ( + image_resize_algorithm or Image.Resampling.BILINEAR + ) + + if hasattr(processor, "image_token"): + self.image_token = processor.image_token + self.image_token_id = processor.tokenizer.convert_tokens_to_ids( + 
self.image_token + ) + + def __call__(self, examples: list[dict]) -> list[dict]: + """ + Preprocess conversation examples to ensure consistent format. + Converts different conversation formats to OpenAI format with 'messages'. + Supports two formats: + 1. OpenAI format with 'messages' + 2. Legacy format with 'conversations' + + Args: + examples: list of conversation dictionaries + + Returns: + list of dicts in OpenAI format with 'messages' key + + Raises: + ValueError: If the conversation format is not supported + """ + role_mapping = { + "human": "user", + "gpt": "assistant", + } + + def normalize_role(role: str) -> str: + """Normalize role names to OpenAI format. Default to original role if not found.""" + return role_mapping.get(role, role) + + def convert_legacy_format(example: dict) -> dict: + """Convert legacy 'conversations' format to OpenAI 'messages' format.""" + messages = [ + {"role": normalize_role(convo["from"]), "content": convo["value"]} + for convo in example["conversations"] + ] + + # Create new dict without 'conversations' key + result = deepcopy(example) + result.pop("conversations") + result["messages"] = messages + return result + + def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]: + """Convert regular messages format to Messages format with content type""" + + new_messages = [] + for message in messages: + if isinstance(message["content"], str): + new_messages.append( + { + "role": message["role"], + "content": [ + { + "type": "text", + "text": message["content"], + } + ], + } + ) + elif isinstance(message["content"], list): + content = message["content"] + + new_messages.append( + { + "role": message["role"], + "content": content, + } + ) + + return new_messages + + processed_examples = [] + for example in examples: + if not ("messages" in example or "conversations" in example): + raise ValueError( + "Only `messages` and `conversations` message keys are currently supported." 
+ ) + + processed_example = None + if "messages" in example: # OpenAI format + processed_example = example + else: # Legacy format + processed_example = convert_legacy_format(example) + + # convert regular messages format to Messages format with content type + # for compatibility with apply_chat_template + processed_example["messages"] = convert_messages_to_multimedia_messages( + processed_example["messages"] + ) + + # find the image key if it exists + possible_image_keys = ["images", "image"] + image_key = None + for key in possible_image_keys: + if key in processed_example: + image_key = key + break + + # if the image key exists, add the image to the first message + if image_key is not None: + # TODO: check if it's normal to be single image only for common datasets + # From observation, it's usually a list of single image but some datasets may have several columns for images + # Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages + image_value = processed_example[image_key][0] + + # Handle image loading (Image, url, path, base64) + image_value = load_image(image_value) + + if self.image_size is not None: + assert hasattr( + image_value, "resize" + ), "Image does not have a resize method" + + if isinstance(self.image_size, tuple): + image_value = image_value.resize( + self.image_size, self.image_resize_algorithm + ) + else: + # Set the padding value; here we use black (0, 0, 0) for RGB images + padding_color = (0, 0, 0) + + # When image_size is an int (square target), preserve aspect ratio then pad + # This is to prevent aspect ratio distortion when resizing to square + image_value = ImageOps.pad( + image_value, + (self.image_size, self.image_size), + method=self.image_resize_algorithm, + color=padding_color, + ) + + # Look for any image type in the first message + # some dataset have an {type: "image"} in the first message + ind_to_add = None + + for i, content in enumerate( + processed_example["messages"][0]["content"] + ): + # Usually datasets created with image columns, don't have it in the messages itself + if content["type"] == "image" and all( + k not in content for k in ["image", "url", "path", "base64"] + ): + ind_to_add = i + break + + # If an image type is found, add the image to that index + if ind_to_add is not None: + processed_example["messages"][0]["content"][ind_to_add][ + "image" + ] = image_value + else: + # if no image type is found, add it to end of the first message + processed_example["messages"][0]["content"].append( + { + "type": "image", + "image": image_value, + } + ) + + processed_examples.append(processed_example) + + return processed_examples + + def process_labels(self, input_ids: Tensor) -> Tensor: + labels = input_ids.clone() + + # The labels are the input_ids, and we mask the padding tokens in the loss computation + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + + # Ignore the image token index in the loss computation (model specific) + labels[labels == self.image_token_id] = -100 + + return labels + + +class Qwen2VLProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Qwen2-VL""" + + def __init__( + self, + processor: ProcessorMixin, + chat_template: Optional[str] = None, + image_size: int | tuple[int, int] | None = None, + image_resize_algorithm: Resampling | None = None, + ): + super().__init__(processor, chat_template, image_size, image_resize_algorithm) + self.image_token = "<|image_pad|>" # nosec + self.image_token_id = 
processor.tokenizer.convert_tokens_to_ids( + self.image_token + ) + + +class Gemma3ProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Gemma3""" + + def __init__( + self, + processor: ProcessorMixin, + chat_template: Optional[str] = None, + image_size: int | tuple[int, int] | None = None, + image_resize_algorithm: Resampling | None = None, + ): + super().__init__(processor, chat_template, image_size, image_resize_algorithm) + self.image_token = processor.tokenizer.special_tokens_map["boi_token"] + self.image_token_id = processor.tokenizer.convert_tokens_to_ids( + self.image_token + ) + + def process_labels(self, input_ids): + labels = input_ids.clone() + + # Follows https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + labels[labels == self.image_token_id] = -100 + labels[labels == 262144] = -100 # corresponds to + + return labels + + +def get_processing_strategy( + processor: ProcessorMixin, + chat_template, + chat_template_type, + image_size: int | tuple[int, int] | None = None, + image_resize_algorithm: Resampling | None = None, +): + if chat_template_type == "qwen2_vl": + return Qwen2VLProcessingStrategy( + processor, chat_template, image_size, image_resize_algorithm + ) + if chat_template_type == "gemma3": + return Gemma3ProcessingStrategy( + processor, chat_template, image_size, image_resize_algorithm + ) + if chat_template_type in [ + "llama3_2_vision", + "llava", + "mistral_v7_tekken", + "pixtral", + ]: + return ProcessingStrategy( + processor, chat_template, image_size, image_resize_algorithm + ) + raise ValueError(f"Unsupported chat template type: {chat_template_type}") diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 7dbeda462..ba0516eb9 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -20,12 +20,14 @@ _CHAT_TEMPLATES = { "mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1... "mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large... 
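To see how these new pieces fit together outside the trainer: `get_processing_strategy` above is what `trainer_builder.py` now constructs and hands to `MultiModalChatDataCollator`. Below is a minimal wiring sketch; the Qwen2-VL checkpoint, the text-only example row, and falling back to the processor's built-in chat template (`chat_template=None`) are assumptions for illustration, and the collator's `apply_chat_template` call needs `transformers>=4.49`.

```python
# Sketch: wire the new processing strategy into the multimodal collator by hand.
# Checkpoint, text-only example, and chat_template=None (processor default) are assumptions.
from transformers import AutoProcessor, AutoTokenizer

from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator

model_id = "Qwen/Qwen2-VL-7B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

strategy = get_processing_strategy(
    processor,
    chat_template=None,            # fall back to the processor's own template
    chat_template_type="qwen2_vl",
    image_size=448,                # optional: pad/resize images to 448x448
)

collator = MultiModalChatDataCollator(tokenizer=tokenizer, processing_strategy=strategy)
batch = collator(
    [{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}]
)
print(batch["input_ids"].shape, batch["labels"].shape)
```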
"mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral... + "mistral_v7_tekken": "{%- set today = strftime_now(\"%Y-%m-%d\") %}\n{%- set default_system_message = \"You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\\nYour knowledge base was last updated on 2023-10-01. The current date is \" + today + \".\\n\\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \\\"What are some good restaurants around me?\\\" => \\\"Where are you?\\\" or \\\"When is the next flight to Tokyo\\\" => \\\"Where do you travel from?\\\")\" %}\n\n{{- bos_token }}\n\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set system_message = default_system_message %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}\n\n{%- for message in loop_messages %}\n {%- if message['role'] == 'user' %}\n\t {%- if message['content'] is string %}\n {{- '[INST]' + message['content'] + '[/INST]' }}\n\t {%- else %}\n\t\t {{- '[INST]' }}\n\t\t {%- for block in message['content'] %}\n\t\t\t {%- if block['type'] == 'text' %}\n\t\t\t\t {{- block['text'] }}\n\t\t\t {%- elif block['type'] == 'image' or block['type'] == 'image_url' %}\n\t\t\t\t {{- '[IMG]' }}\n\t\t\t\t{%- else %}\n\t\t\t\t {{- raise_exception('Only text and image blocks are supported in message content!') }}\n\t\t\t\t{%- endif %}\n\t\t\t{%- endfor %}\n\t\t {{- '[/INST]' }}\n\t\t{%- endif %}\n {%- elif message['role'] == 'system' %}\n {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}\n {%- elif message['role'] == 'assistant' %}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- elif message['content'] is iterable %} \n\t\t {%- for block in message['content'] %}\n\t\t\t {%- if block['type'] == 'text' %}\n\t\t\t\t {{- block['text'] }}\n\t\t\t {%- else %}\n\t\t\t\t {{- raise_exception('Only text blocks are supported in assistant message content!') }} {%- endif %}\n\t\t\t \n\t\t\t{%- endfor %} {{- eos_token }} {%- else %}\n {{- raise_exception('Unsupported assistant message content format!') }} \n{%- endif %} \n{%- else %}\n {{- raise_exception('Only user, system and assistant roles are supported!') }}\n {%- endif %}\n{%- endfor %}", "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% 
for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", - "gemma3_text": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'model\n'}}\n{%- endif -%}\n", + "gemma3": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'model\n'}}\n{%- endif -%}\n", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. 
You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "llama3_2_vision": '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." 
}}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n', + "llava": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 
'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", @@ -34,6 +36,8 @@ _CHAT_TEMPLATES = { "qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", "exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}", "metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' 
%}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}", + "pixtral": '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n{%- endif %}\n{%- endfor %}\n{{- eos_token }}\n{%- else %}\n{{- message["content"] + eos_token }}\n{%- endif %}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}', + "qwen2_vl": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}", } diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index 6f8a64ad8..75d72f8dc 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -2,15 +2,17 @@ Collators for multi-modal chat messages and packing """ -from copy import deepcopy from dataclasses import dataclass from typing import Any, Optional, Union -from PIL import Image -from transformers import PreTrainedTokenizerBase, ProcessorMixin 
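The `pixtral` and `qwen2_vl` template strings added above can be sanity-checked in isolation by rendering them through any tokenizer, since the template is passed in explicitly. A small sketch follows; the checkpoint is an assumption, and the template is looked up from the module-level `_CHAT_TEMPLATES` table shown above.

```python
# Sketch: render the newly added qwen2_vl chat template on a mixed image+text turn.
# The tokenizer checkpoint is an assumption; the template comes from _CHAT_TEMPLATES above.
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import _CHAT_TEMPLATES

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    },
]
rendered = tokenizer.apply_chat_template(
    messages,
    chat_template=_CHAT_TEMPLATES["qwen2_vl"],
    tokenize=False,
    add_generation_prompt=True,
)
print(rendered)
# Expect the image placeholder to render as <|vision_start|><|image_pad|><|vision_end|>
# followed by the text, wrapped in <|im_start|>/<|im_end|> turns.
```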
+import torch +from torch import Tensor +from transformers import PreTrainedTokenizerBase from transformers.data.data_collator import DataCollatorMixin from transformers.utils import PaddingStrategy +from axolotl.processing_strategies import ProcessingStrategy + @dataclass class MultiModalChatDataCollator(DataCollatorMixin): @@ -19,11 +21,9 @@ class MultiModalChatDataCollator(DataCollatorMixin): """ tokenizer: PreTrainedTokenizerBase - processor: ProcessorMixin - return_tensors: str = "pt" - chat_template: Optional[str] = None + processing_strategy: ProcessingStrategy packing: bool = False - max_images: int = -1 + return_tensors: str = "pt" padding: Union[bool, str, PaddingStrategy] = True pad_to_multiple_of: Optional[int] = None @@ -31,162 +31,62 @@ class MultiModalChatDataCollator(DataCollatorMixin): if self.packing: raise ValueError("Packing is currently not supported.") - def torch_call( - self, examples: list[Union[list[int], Any, dict[str, Any]]] - ) -> dict[str, Any]: - # Handle dict or lists with proper padding and conversion to tensor. - - return self.__class__.process_rows( - examples, self.processor, self.chat_template, self.max_images - ) - - @staticmethod - def process_rows(examples, processor, chat_template, max_images, length_only=False): - # HINT: use `_torch_collate_batch` to stack and pad tensors - # see also DataCollatorWithFlattening and DefaultDataCollator - - # *** This is COPIED from the trl example sft_vlm.py code *** - # use this as a starting point - - def _preprocess(examples: list[dict]) -> list[dict]: - """ - Preprocess conversation examples to ensure consistent format. - - Converts different conversation formats to OpenAI format with 'messages'. - Supports two formats: - 1. OpenAI format with 'messages' - 2. Legacy format with 'conversations' - - Args: - examples: list of conversation dictionaries - - Returns: - dict in OpenAI format with 'messages' key - - Raises: - ValueError: If the conversation format is not supported - """ - role_mapping = { - "human": "user", - "gpt": "assistant", - } - - def normalize_role(role: str) -> str: - """Normalize role names to OpenAI format. Default to original role if not found.""" - return role_mapping.get(role, role) - - def convert_legacy_format(example: dict) -> dict: - """Convert legacy 'conversations' format to OpenAI 'messages' format.""" - messages = [ - { - "role": normalize_role(convo["from"]), - "content": convo["value"], - } - for convo in example["conversations"] - ] - - # Create new dict without 'conversations' key - result = deepcopy(example) - result.pop("conversations") - return {"messages": messages, **result} - - processed_examples = [] - for example in examples: - # OpenAI format - if "messages" in example: - processed_examples.append(example) - - # Legacy format - elif "conversations" in example: - processed_examples.append(convert_legacy_format(example)) - - else: - raise ValueError( - "Only `messages` and `conversations` message keys are currently supported." - ) - - return processed_examples - - def _process_images(examples, max_images): - """ - Process images from examples, ensuring consistency in image presence and applying max_images limit. 
- - Args: - examples: List of dictionaries that may contain 'images' key - max_images: Maximum number of images to keep per example (0 means no limit) - - Returns: - Either None (if no images) or List[Image objects] (if all examples have images) - - Raises: - ValueError: If there's a mix of None and non-None images - """ - - def get_image(example): - if "images" not in example: - return None - images = example["images"] - if isinstance(images, str): - return Image.open(images) - return images - - images = [get_image(example) for example in examples] - - # Count None and non-None images - none_count = sum(1 for img in images if img is None) - - # All images are None - if none_count == len(images): - return None - - # Mix of None and non-None images - if none_count > 0: - raise ValueError( - "All images should be either None or not None. " - "Please provide images for all examples or None." - ) - - # Apply max_images limit if specified - if max_images > 0: - images = [ - ( - img_batch[:max_images] - if isinstance(img_batch, (list, tuple)) - else img_batch - ) - for img_batch in images - ] - - return images + def torch_call(self, examples: list[dict]) -> dict[str, Any]: + return self.process_rows(examples) + def process_rows( + self, + examples: list[dict], + ) -> dict[str, Tensor]: # Preprocess the examples - examples = _preprocess(examples) + examples = self.processing_strategy(examples) - # Get the texts and images, and apply the chat template - texts = [ - processor.apply_chat_template( - example["messages"], chat_template=chat_template, tokenize=False + # Initialize batch + batch: dict[str, Any] = {} + + # Process each example + for example in examples: + # Apply chat template to process the example + # This method requires transformers>=4.49.0 + result = self.processing_strategy.processor.apply_chat_template( + example["messages"], + add_generation_prompt=True, + tokenize=True, + return_tensors="pt", + padding=True, + return_dict=True, + chat_template=self.processing_strategy.chat_template, ) - for example in examples - ] - images = _process_images(examples, max_images=max_images) + # TODO: Check if need handling for len(input_ids) > sequence_len - # Tokenize the texts and process the images - batch = processor(text=texts, images=images, return_tensors="pt", padding=True) + # Add the processed tensors to our batch + for key in result.keys(): + if key not in batch: + batch[key] = [] - # The labels are the input_ids, and we mask the padding tokens in the loss computation - labels = batch["input_ids"].clone() - labels[labels == processor.tokenizer.pad_token_id] = -100 # - # Ignore the image token index in the loss computation (model specific) - image_token_id = processor.tokenizer.convert_tokens_to_ids( - processor.image_token + batch[key].append(result[key].squeeze(0)) + + # Pad sequences to the same length + input_ids = torch.nn.utils.rnn.pad_sequence( + batch["input_ids"], + batch_first=True, + padding_value=self.tokenizer.pad_token_id, ) - labels[labels == image_token_id] = -100 - batch["labels"] = labels - if length_only: - return { - "length": [len(sample["input_ids"]) for sample in batch["input_ids"]] - } - return batch + attention_mask = torch.nn.utils.rnn.pad_sequence( + batch["attention_mask"], batch_first=True, padding_value=0 + ) + + # Create the final batch + final_batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + # Process the labels + final_batch["labels"] = self.processing_strategy.process_labels( + final_batch["input_ids"] + ) + + return 
final_batch diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 4e956140d..634575066 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -13,7 +13,7 @@ from axolotl.integrations.base import PluginManager from axolotl.integrations.config import merge_input_args from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model_config +from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, ) @@ -158,7 +158,7 @@ def normalize_config(cfg): cfg.is_multimodal = ( hasattr(model_config, "model_type") - and model_config.model_type in ["llava", "mllama"] + and model_config.model_type in MULTIMODAL_AUTO_MODEL_MAPPING or any( multimodal_name in cfg.base_model.lower() for multimodal_name in [ @@ -171,7 +171,6 @@ def normalize_config(cfg): cfg.processor_config = ( cfg.processor_config or cfg.base_model_config or cfg.base_model ) - model_config = model_config.text_config cfg.model_config_type = model_config.model_type diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 83f70a022..23a6e102f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -34,12 +34,16 @@ from transformers import ( # noqa: F401 AutoTokenizer, AwqConfig, BitsAndBytesConfig, + Gemma3ForConditionalGeneration, GPTQConfig, LlavaForConditionalGeneration, + Mistral3ForConditionalGeneration, MllamaForConditionalGeneration, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, + Qwen2_5_VLForConditionalGeneration, + Qwen2VLForConditionalGeneration, ) from transformers.integrations.deepspeed import ( HfTrainerDeepSpeedConfig, @@ -69,9 +73,13 @@ from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_mod LOG = logging.getLogger(__name__) -MULTIMODEL_AUTO_MODEL_MAPPING = { - "llava": LlavaForConditionalGeneration, +MULTIMODAL_AUTO_MODEL_MAPPING = { "mllama": MllamaForConditionalGeneration, + "llava": LlavaForConditionalGeneration, + "qwen2_vl": Qwen2VLForConditionalGeneration, + "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration, + "mistral3": Mistral3ForConditionalGeneration, + "gemma3": Gemma3ForConditionalGeneration, } @@ -101,7 +109,21 @@ def get_module_class_from_name(module, name): def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]): if cfg.is_multimodal: - model_config = model_config.text_config + if hasattr(model_config, "text_config"): + model_config = model_config.text_config + model_config.use_cache = False + elif hasattr(model_config, "get_text_config"): + model_config = model_config.get_text_config() + model_config.use_cache = False + + # check if image_size is not set and load image size from model config if available + if ( + cfg.image_size is None + and hasattr(model_config, "vision_config") + and hasattr(model_config.vision_config, "image_size") + ): + cfg.image_size = model_config.vision_config.image_size + LOG.debug(f"Loaded image size: {cfg.image_size} from model config") quant_config_exists = ( hasattr(model_config, "quantization_config") @@ -440,6 +462,31 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): **processor_kwargs, ) + # Attempt to load image size from processor if available + if ( + cfg.image_size is None + and hasattr(processor, "size") + and any(dim in processor.size for dim in ["width", 
"height"]) + ): + im_width = None + im_height = None + if "width" in processor.size: + im_width = processor.size["width"] + if "height" in processor.size: + im_height = processor.size["height"] + + # If both width and height are set, use a tuple + if im_width is not None and im_height is not None: + cfg.image_size = (im_width, im_height) + # If only width is set, use as integer + elif im_width is not None: + cfg.image_size = im_width + # If only height is set, use as integer + elif im_height is not None: + cfg.image_size = im_height + + LOG.debug(f"Loaded image size: {cfg.image_size} from processor") + return processor @@ -477,7 +524,11 @@ class ModelLoader: # init model config self.model_config = load_model_config(cfg) if cfg.is_multimodal: - self.text_model_config = self.model_config.text_config + if hasattr(self.model_config, "text_config"): + self.text_model_config = self.model_config.text_config + else: + # for qwen2_vl + self.text_model_config = self.model_config.get_text_config() else: self.text_model_config = self.model_config @@ -673,7 +724,7 @@ class ModelLoader: should be set according to the type of the model. """ if self.cfg.is_multimodal: - self.auto_model_loader = MULTIMODEL_AUTO_MODEL_MAPPING.get( + self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get( self.model_config.model_type, AutoModelForVision2Seq ) @@ -1194,7 +1245,9 @@ class ModelLoader: ) ): resize_kwargs = {} - if self.cfg.mean_resizing_embeddings is not None: + if self.cfg.mean_resizing_embeddings is not None and not ( + self.model_config.model_type == "llava" + ): resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings self.model.resize_token_embeddings(embeddings_len, **resize_kwargs) else: diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 7992e6559..d52146092 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -42,6 +42,7 @@ from axolotl.utils.schemas.model import ( ModelOutputConfig, SpecialTokensConfig, ) +from axolotl.utils.schemas.multimodal import MultiModalConfig from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig from axolotl.utils.schemas.training import HyperparametersConfig from axolotl.utils.schemas.trl import TRLConfig @@ -64,6 +65,7 @@ class AxolotlInputConfig( LISAConfig, GradioConfig, RayConfig, + MultiModalConfig, RemappedParameters, DeprecatedParameters, BaseModel, diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index a0c6df710..ad735afe0 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -22,8 +22,8 @@ class ChatTemplate(str, Enum): mistral_v1 = "mistral_v1" # pylint: disable=invalid-name mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name + mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name gemma = "gemma" # pylint: disable=invalid-name - gemma3_text = "gemma3_text" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name @@ -37,6 +37,10 @@ class ChatTemplate(str, Enum): tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name exaone = "exaone" # pylint: disable=invalid-name metharme = "metharme" # pylint: disable=invalid-name + pixtral = "pixtral" # pylint: disable=invalid-name + llava = "llava" # pylint: disable=invalid-name + qwen2_vl = "qwen2_vl" # 
pylint: disable=invalid-name + gemma3 = "gemma3" # pylint: disable=invalid-name class CustomSupportedOptimizers(str, Enum): diff --git a/src/axolotl/utils/schemas/multimodal.py b/src/axolotl/utils/schemas/multimodal.py new file mode 100644 index 000000000..a3449199f --- /dev/null +++ b/src/axolotl/utils/schemas/multimodal.py @@ -0,0 +1,48 @@ +"""Pydantic models for multimodal-related configuration""" + +from typing import Literal + +from PIL.Image import Resampling +from pydantic import BaseModel, Field, field_validator + + +class MultiModalConfig(BaseModel): + """Multi-modal configuration subset""" + + image_size: int | tuple[int, int] | None = Field( + default=None, + json_schema_extra={ + "description": ( + "The size of the image to resize to. It can be an integer (resized into padded-square image) or a tuple (width, height)." + "If not provided, we will attempt to load from preprocessor.size, otherwise, images won't be resized." + ) + }, + ) + image_resize_algorithm: ( + Literal["bilinear", "bicubic", "lanczos"] | Resampling | None + ) = Field( + default=None, + json_schema_extra={ + "description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details." + }, + ) + + @field_validator("image_resize_algorithm", mode="before") + @classmethod + def convert_image_resize_algorithm(cls, image_resize_algorithm): + """ + Convert the image resize algorithm to a PIL.Image.Resampling enum. + """ + if isinstance(image_resize_algorithm, str): + image_resize_algorithm = image_resize_algorithm.lower() + if image_resize_algorithm == "bilinear": + image_resize_algorithm = Resampling.BILINEAR + elif image_resize_algorithm == "bicubic": + image_resize_algorithm = Resampling.BICUBIC + elif image_resize_algorithm == "lanczos": + image_resize_algorithm = Resampling.LANCZOS + else: + raise ValueError( + f"Invalid image resize algorithm: {image_resize_algorithm}" + ) + return image_resize_algorithm diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py index 6ce360f68..14423ce73 100644 --- a/tests/e2e/test_gemma3_text.py +++ b/tests/e2e/test_gemma3_text.py @@ -58,7 +58,7 @@ class TestGemma3Text: "bos_token": "", "eos_token": "", }, - "chat_template": "gemma3_text", + "chat_template": "gemma3", "num_epochs": 1, "micro_batch_size": 1, "gradient_accumulation_steps": 4, @@ -105,7 +105,7 @@ class TestGemma3Text: "split": "train[:1%]", }, ], - "chat_template": "gemma3_text", + "chat_template": "gemma3", "special_tokens": { "bos_token": "", "eos_token": "",
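Finally, a note on the new `MultiModalConfig` schema: its `field_validator` maps the lowercase algorithm names onto `PIL.Image.Resampling` members, so YAML configs can keep using plain strings. A small standalone sketch of that behavior (the concrete sizes are arbitrary):

```python
# Sketch: the new MultiModalConfig normalizes resize settings given as plain YAML values.
from PIL.Image import Resampling

from axolotl.utils.schemas.multimodal import MultiModalConfig

cfg = MultiModalConfig(image_size=512, image_resize_algorithm="lanczos")
assert cfg.image_resize_algorithm is Resampling.LANCZOS  # string mapped to the enum member

# A (width, height) tuple is also accepted; it triggers an exact resize rather than
# the aspect-preserving pad-to-square path used for a single integer.
cfg = MultiModalConfig(image_size=(448, 336), image_resize_algorithm="bilinear")
print(cfg.image_size, cfg.image_resize_algorithm)
```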