diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd
index dbb365f73..d839ce211 100644
--- a/docs/multimodal.qmd
+++ b/docs/multimodal.qmd
@@ -13,10 +13,13 @@ format:
 - [Pixtral](#sec-pixtral)
 - [Llava-1.5](#sec-llava-15)
 - [Mistral-Small-3.1](#sec-mistral-small-31)
+- [Voxtral](#sec-voxtral)
 - [Gemma-3](#sec-gemma-3)
 - [Gemma-3n](#sec-gemma-3n)
 - [Qwen2-VL](#sec-qwen2-vl)
 - [Qwen2.5-VL](#sec-qwen25-vl)
+- [SmolVLM2](#sec-smolvlm2)
+- [LFM2-VL](#sec-lfm2-vl)

 ## Usage

@@ -31,7 +34,7 @@ skip_prepare_dataset: true
 remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
 sample_packing: false # not yet supported with multimodal

-chat_template: # see in next section
+chat_template: # optional - set this only if the model section below specifies one

 # example dataset
 datasets:
@@ -97,6 +100,16 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
 chat_template: mistral_v7_tekken
 ```

+### Voxtral {#sec-voxtral}
+
+::: {.callout-tip}
+Please make sure to install the audio dependencies via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
+:::
+
+```yaml
+base_model: mistralai/Voxtral-Mini-3B-2507
+```
+
 ### Gemma-3 {#sec-gemma-3}

 ::: {.callout-tip}
@@ -143,6 +156,26 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
 chat_template: qwen2_vl # same as qwen2-vl
 ```

+### SmolVLM2 {#sec-smolvlm2}
+
+::: {.callout-tip}
+Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
+:::
+
+```yaml
+base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
+```
+
+### LFM2-VL {#sec-lfm2-vl}
+
+::: {.callout-warning}
+The `causal-conv1d` package is incompatible with LFM2 models. Please uninstall it via `pip3 uninstall -y causal-conv1d`
+:::
+
+```yaml
+base_model: LiquidAI/LFM2-VL-450M
+```
+
 ## Dataset Format

 For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
@@ -181,6 +214,20 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`.

 :::

+### Video
+
+::: {.callout-warning}
+
+This is not well tested at the moment. We welcome contributors!
+
+:::
+
+For video loading, you can use the following keys within `content` alongside `"type": "video"` (see the sample message below):
+
+- `"path": "/path/to/video.mp4"`
+- `"url": "https://example.com/video.mp4"`
+- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or a list of the aforementioned)
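+
+As a rough sketch (the file path and prompt text below are placeholders), a single user turn that pairs a video with a text instruction could look like:
+
+```json
+{
+  "role": "user",
+  "content": [
+    { "type": "video", "path": "/path/to/video.mp4" },
+    { "type": "text", "text": "Describe what happens in this clip." }
+  ]
+}
+```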
+
 ### Example

 Here is an example of a multi-modal dataset:

diff --git a/examples/LiquidAI/README.md b/examples/LiquidAI/README.md
new file mode 100644
index 000000000..96fc74a92
--- /dev/null
+++ b/examples/LiquidAI/README.md
@@ -0,0 +1,58 @@
+# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl
+
+[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.
+
+LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.
+
+This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
+
+## Getting Started
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
+
+   Here is an example of how to install via pip:
+   ```bash
+   # Ensure you have a compatible version of PyTorch installed
+   pip3 install packaging setuptools wheel ninja
+   pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+   ```
+
+2. Run one of the finetuning examples below.
+
+   **LFM2**
+   ```bash
+   # FFT SFT (1x48GB @ 25GiB)
+   axolotl train examples/LiquidAI/lfm2-350m-fft.yaml
+   ```
+
+   **LFM2-VL**
+   ```bash
+   # LoRA SFT (1x48GB @ 2.7GiB)
+   axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
+   ```
+
+### TIPS
+
+- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
+  ```bash
+  pip uninstall -y causal-conv1d
+  ```
+
+- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
+- **Dataset Formats**:
+  - For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
+  - For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.
+
+## Optimization Guides
+
+- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
+- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
+- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
+
+## Related Resources
+
+- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
+- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
+- [Axolotl Docs](https://docs.axolotl.ai)
+- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
+- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
diff --git a/examples/lfm2/lfm2-350m-fft.yaml b/examples/LiquidAI/lfm2-350m-fft.yaml
similarity index 96%
rename from examples/lfm2/lfm2-350m-fft.yaml
rename to examples/LiquidAI/lfm2-350m-fft.yaml
index 16a0a028e..d19815008 100644
--- a/examples/lfm2/lfm2-350m-fft.yaml
+++ b/examples/LiquidAI/lfm2-350m-fft.yaml
@@ -2,7 +2,6 @@ base_model: LiquidAI/LFM2-350M

 chunked_cross_entropy: true

-chat_template: tokenizer_default
 eot_tokens:
 - "<|im_end|>"
 datasets:
diff --git a/examples/LiquidAI/lfm2-vl-lora.yaml b/examples/LiquidAI/lfm2-vl-lora.yaml
new file mode 100644
index 000000000..7fee17f92
--- /dev/null
+++ b/examples/LiquidAI/lfm2-vl-lora.yaml
@@ -0,0 +1,58 @@
+base_model: LiquidAI/LFM2-VL-450M
+trust_remote_code: true
+model_type: AutoModelForImageTextToText
+processor_type: AutoProcessor
+
+# these 3 lines are needed for now to handle vision chat templates with images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 8192
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+logging_steps: 1
+flash_attention: true
+eager_attention:
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/lfm2/README.md b/examples/lfm2/README.md
deleted file mode 100644
index eb9ca911f..000000000
--- a/examples/lfm2/README.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Liquid Foundation Models 2
-
-LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
-
-```bash
-pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
-```
diff --git a/examples/smolvlm2/README.md b/examples/smolvlm2/README.md
new file mode 100644
index 000000000..9c0ae4836
--- /dev/null
+++ b/examples/smolvlm2/README.md
@@ -0,0 +1,49 @@
+# Finetune SmolVLM2 with Axolotl
+
+[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) is a family of lightweight, open-source multimodal models from Hugging Face designed to analyze and understand video, image, and text content.
+
+These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.
+
+This guide shows how to fine-tune SmolVLM2 models with Axolotl.
+
+## Getting Started
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
+
+   Here is an example of how to install via pip:
+   ```bash
+   # Ensure you have a compatible version of PyTorch installed
+   pip3 install packaging setuptools wheel ninja
+   pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+   ```
+
+2. Install an extra dependency:
+
+   ```bash
+   pip3 install num2words==0.5.14
+   ```
+
+3. Run the finetuning example:
+
+   ```bash
+   # LoRA SFT (1x48GB @ 6.8GiB)
+   axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml
+   ```
+
+## TIPS
+
+- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
+- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html). A minimal config sketch follows below.
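+
+As a rough sketch (the file name below is a placeholder), pointing the `datasets` config at data you have prepared in that format could look like:
+
+```yaml
+datasets:
+  - path: ./data/my_video_sft.jsonl  # placeholder path to your prepared dataset
+    type: chat_template
+```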
+
+## Optimization Guides
+
+- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
+- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
+- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
+
+## Related Resources
+
+- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)
+- [Axolotl Docs](https://docs.axolotl.ai)
+- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
+- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
diff --git a/examples/smolvlm2/smolvlm2-2B-lora.yaml b/examples/smolvlm2/smolvlm2-2B-lora.yaml
new file mode 100644
index 000000000..1aeff408d
--- /dev/null
+++ b/examples/smolvlm2/smolvlm2-2B-lora.yaml
@@ -0,0 +1,56 @@
+base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
+trust_remote_code: true
+processor_type: AutoProcessor
+
+# these 3 lines are needed for now to handle vision chat templates with images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 8192
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+logging_steps: 1
+flash_attention: true
+eager_attention:
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/src/axolotl/loaders/constants.py b/src/axolotl/loaders/constants.py
index 3fabf9d94..4939cb28d 100644
--- a/src/axolotl/loaders/constants.py
+++ b/src/axolotl/loaders/constants.py
@@ -1,26 +1,13 @@
 """Shared constants for axolotl.loaders module"""

-from transformers import (
-    Gemma3ForConditionalGeneration,
-    Gemma3nForConditionalGeneration,
-    Llama4ForConditionalGeneration,
-    LlavaForConditionalGeneration,
-    Mistral3ForConditionalGeneration,
-    MllamaForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
-    Qwen2VLForConditionalGeneration,
+from transformers import AutoModelForImageTextToText
+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
 )

-MULTIMODAL_AUTO_MODEL_MAPPING = {
-    "mllama": MllamaForConditionalGeneration,
-    "llama4": Llama4ForConditionalGeneration,
-    "llava": LlavaForConditionalGeneration,
-    "qwen2_vl": Qwen2VLForConditionalGeneration,
-    "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
-    "mistral3": Mistral3ForConditionalGeneration,
-    "gemma3": Gemma3ForConditionalGeneration,
-    "gemma3n": Gemma3nForConditionalGeneration,
-}
+MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
+
+MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText

 try:
     from transformers import VoxtralForConditionalGeneration
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 6bf1f149b..53ae428a2 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -25,6 +25,7 @@ from peft import (
 from torch.distributed import DeviceMesh
 from transformers import (
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForVision2Seq,
     AwqConfig,
     BitsAndBytesConfig,
@@ -212,6 +213,7 @@ class ModelLoader:
         self.model_kwargs["use_kernels"] = self.cfg.use_kernels
         self._set_quantization_config()
         self._set_attention_config()
+        self._check_model_requirements()

     def _apply_post_model_load_setup(self):
         """Configure the model after it has been loaded."""
@@ -432,6 +434,8 @@ class ModelLoader:
             self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
                 self.model_config.model_type, AutoModelForVision2Seq
             )
+            if isinstance(self.auto_model_loader, str):
+                self.auto_model_loader = AutoModelForImageTextToText

     def _set_device_map_config(self):
         """Setup `device_map` according to config"""
@@ -628,6 +632,16 @@ class ModelLoader:
         if self.cfg.low_cpu_mem_usage:
             self.model_kwargs["low_cpu_mem_usage"] = True

+    def _check_model_requirements(self):
+        if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]:
+            from transformers.utils.import_utils import is_causal_conv1d_available
+
+            if is_causal_conv1d_available():
+                raise ImportError(
+                    "The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. "
+                    "Please uninstall it by running: `pip uninstall -y causal-conv1d`"
+                )
+
     def _configure_zero3_memory_efficient_loading(
         self,
     ) -> HfTrainerDeepSpeedConfig | None:
diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py
index 4cc5e85a1..31597d5a6 100644
--- a/src/axolotl/processing_strategies.py
+++ b/src/axolotl/processing_strategies.py
@@ -6,7 +6,7 @@ from typing import Optional
 from PIL import Image, ImageOps
 from PIL.Image import Resampling
 from torch import Tensor, zeros_like
-from transformers import ProcessorMixin, VoxtralProcessor
+from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
 from transformers.image_utils import load_image

 from axolotl.utils.dict import remove_none_values
@@ -138,7 +138,7 @@ class ProcessingStrategy:
                     image_key = key
                     break

-            # if the image key exists, add the image to the first message
+            # if the image key exists, add the image to the first user message
            if image_key is not None and processed_example[image_key] is not None:
                 # TODO: check if it's normal to be single image only for common datasets
                 # From observation, it's usually a list of single image but some datasets may have several columns for images
@@ -179,26 +179,34 @@ class ProcessingStrategy:

                 # Look for any image type in the first message
                 # some dataset have an {type: "image"} in the first message
+                msg_ind_to_add = None
                 ind_to_add = None
+                first_user_idx = None

-                for i, content in enumerate(
-                    processed_example["messages"][0]["content"]
-                ):
-                    # Usually datasets created with image columns, don't have it in the messages itself
-                    if content["type"] == "image" and all(
-                        k not in content for k in ["image", "url", "path", "base64"]
+                for msg_idx, msg_content in enumerate(processed_example["messages"]):
+                    if first_user_idx is None and msg_content["role"] == "user":
+                        first_user_idx = msg_idx
+                    for i, content in enumerate(
+                        processed_example["messages"][msg_idx]["content"]
                     ):
-                        ind_to_add = i
-                        break
+                        # Usually datasets created with image columns, don't have it in the messages itself
+                        if content["type"] == "image" and all(
+                            k not in content for k in ["image", "url", "path", "base64"]
+                        ):
+                            msg_ind_to_add = msg_idx
+                            ind_to_add = i
+                            break

                 # If an image type is found, add the image to that index
-                if ind_to_add is not None:
-                    processed_example["messages"][0]["content"][ind_to_add][
-                        "image"
-                    ] = image_value
+                if ind_to_add is not None and msg_ind_to_add is not None:
+                    processed_example["messages"][msg_ind_to_add]["content"][
+                        ind_to_add
+                    ]["image"] = image_value
                 else:
-                    # if no image type is found, add it to end of the first message
-                    processed_example["messages"][0]["content"].append(
+                    # if no image type is found, add it to the end of the first user message
+                    if first_user_idx is None:
+                        first_user_idx = 0
+                    processed_example["messages"][first_user_idx]["content"].append(
                         {
                             "type": "image",
                             "image": image_value,
@@ -395,6 +403,24 @@ class ProcessingStrategy:
         return labels


+class SmolVLM2ProcessingStrategy(ProcessingStrategy):
+    """Processing Strategy class for SmolVLM2"""
+
+    def __init__(
+        self,
+        processor: ProcessorMixin,
+        chat_template: Optional[str] = None,
+        image_size: int | tuple[int, int] | None = None,
+        image_resize_algorithm: Resampling | None = None,
+    ):
+        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
+        self.image_token = "<image>"  # nosec
+
+        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
+            processor.tokenizer.additional_special_tokens.index(self.image_token)
+        ]
+
+
 def get_processing_strategy(
     processor: ProcessorMixin,
     chat_template,
@@ -402,32 +428,43 @@
     image_size: int | tuple[int, int] | None = None,
     image_resize_algorithm: Resampling | None = None,
 ):
+    processing_kwargs = {
+        "processor": processor,
+        "chat_template": chat_template,
+        "image_size": image_size,
+        "image_resize_algorithm": image_resize_algorithm,
+    }
+
+    if chat_template_type in [None, "tokenizer_default"] and hasattr(
+        processor.tokenizer, "chat_template"
+    ):
+        processing_kwargs["chat_template"] = processor.tokenizer.chat_template
+
     if chat_template_type == "qwen2_vl":
         return Qwen2VLProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )
     if chat_template_type == "gemma3":
         return Gemma3ProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
        )
     if chat_template_type == "gemma3n":
         return Gemma3nProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
-        )
-    if chat_template_type in [
-        "llama3_2_vision",
-        "llama4",
-        "llava",
-        "mistral_v7_tekken",
-        "pixtral",
-    ]:
-        return ProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )
     if isinstance(processor, VoxtralProcessor):
         return VoxtralProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )
-    raise ValueError(f"Unsupported chat template type: {chat_template_type}")
+    if isinstance(processor, SmolVLMProcessor):
+        return SmolVLM2ProcessingStrategy(
+            **processing_kwargs,
+        )
+
+    # llama3_2_vision, llama4, llava
+    # mistral_v7_tekken, pixtral, lfm2vl
+    return ProcessingStrategy(
+        **processing_kwargs,
+    )
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 8241dd385..f927b7fcb 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -129,13 +129,21 @@ class ChatTemplatePrompter(Prompter):
                 images=images,
                 return_tensors="pt",
             )
+            if hasattr(batch, "to_dict"):
+                batch = batch.to_dict()
+            else:
+                batch = dict(batch)
+
             # workaround since processor works in batches instead of single examples
+            out = {}
             for k, val in batch.items():
-                if k in ["pixel_values"]:
-                    batch[k] = val.tolist()
+                if hasattr(val, "tolist"):
+                    out[k] = (
+                        val.tolist() if k == "pixel_values" else val.squeeze(0).tolist()
+                    )
                 else:
-                    batch[k] = val.squeeze().tolist()
-            return batch
+                    out[k] = val
+            return out

         return self.tokenizer.apply_chat_template(
             conversation,
@@ -433,10 +441,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
             tokenized_prompt["attention_mask"] = [1] * len(input_ids)
         else:
             input_ids = tokenized_res["input_ids"]
-            tokenized_prompt = tokenized_res
+            tokenized_prompt = dict(tokenized_res)

         if not self.train_on_inputs:
-            user_prompt_len = len(prompt_ids)
+            if isinstance(prompt_ids, dict):
+                user_prompt_len = len(prompt_ids["input_ids"])
+            else:
+                user_prompt_len = len(prompt_ids)
             labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
         else:
             labels = input_ids
diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py
index 0075d4830..542918527 100644
--- a/src/axolotl/utils/collators/mm_chat.py
+++ b/src/axolotl/utils/collators/mm_chat.py
@@ -5,7 +5,6 @@ Collators for multi-modal chat messages and packing
 from dataclasses import dataclass
 from typing import Any, Optional, Union

-import torch
 from torch import Tensor
 from transformers import PreTrainedTokenizerBase
 from transformers.data.data_collator import DataCollatorMixin
@@ -42,62 +41,19 @@ class MultiModalChatDataCollator(DataCollatorMixin):
         examples = self.processing_strategy(examples)

         # Initialize batch
-        batch: dict[str, Any] = {}
+        messages = [ex["messages"] for ex in examples]

-        # Process each example
-        for example in examples:
-            # Apply chat template to process the example
-            # This method requires transformers>=4.49.0
-            result = self.processing_strategy.processor.apply_chat_template(
-                example["messages"],
-                add_generation_prompt=False,
-                tokenize=True,
-                return_tensors="pt",
-                padding=True,
-                return_dict=True,
-                chat_template=self.processing_strategy.chat_template,
-            )
-
-            # TODO: Check if need handling for len(input_ids) > sequence_len
-
-            # Add the processed tensors to our batch
-            for key in result.keys():
-                if key not in batch:
-                    batch[key] = []
-
-                batch[key].append(result[key].squeeze(0))
-
-        # Pad sequences to the same length
-        input_ids = torch.nn.utils.rnn.pad_sequence(
-            batch["input_ids"],
-            batch_first=True,
-            padding_value=self.tokenizer.pad_token_id,
+        batch = self.processing_strategy.processor.apply_chat_template(
+            messages,
+            add_generation_prompt=False,
+            tokenize=True,
+            return_tensors="pt",
+            padding=True,
+            return_dict=True,
+            chat_template=self.processing_strategy.chat_template,
         )
-        attention_mask = torch.nn.utils.rnn.pad_sequence(
-            batch["attention_mask"], batch_first=True, padding_value=0
-        )
-
-        # Create the final batch
-        final_batch = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-        }
-
-        for key, val in batch.items():
-            if key in ["input_ids", "attention_mask"]:
-                continue
-
-            if key in ["token_type_ids", "cross_attention_mask"]:
-                final_batch[key] = torch.nn.utils.rnn.pad_sequence(
-                    val, batch_first=True, padding_value=0
-                )
-            else:
-                final_batch[key] = torch.stack(val)
-
         # Process the labels
-        final_batch["labels"] = self.processing_strategy.process_labels(
-            final_batch["input_ids"]
-        )
+        batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])

-        return final_batch
+        return batch
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index 5931fe148..939ed5c1c 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -147,7 +147,11 @@ def require_hopper(test_case):


 def check_tensorboard(
-    temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
+    temp_run_dir: str,
+    tag: str,
+    lt_val: float,
+    assertion_err: str,
+    rtol: float = 0.02,
 ) -> None:
     """
     helper function to parse and check tensorboard logs
@@ -157,6 +161,7 @@ def check_tensorboard(
     reader = SummaryReader(event_file)
     df = reader.scalars # pylint: disable=invalid-name
     df = df[(df.tag == tag)] # pylint: disable=invalid-name
+    lt_val = (1 + rtol) * lt_val
     if "%s" in assertion_err:
         assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
     else: