diff --git a/docs/docker.qmd b/docs/docker.qmd index 2908592fa..6bdd77a54 100644 --- a/docs/docker.qmd +++ b/docs/docker.qmd @@ -103,8 +103,7 @@ This uses the same tags as the [`main` image](#sec-main-tags). - `JUPYTER_DISABLE`: Disable Jupyter lab. - `JUPYTER_PASSWORD`: Set a password for the Jupyter lab. -- `PUBLIC_KEY`: Add a public key for the SSH service. -- `SSH_KEY`: Add a private key for the SSH service. +- `PUBLIC_KEY` / `SSH_KEY`: Add a public key for the SSH service. #### Volume mounts diff --git a/examples/cohere/command-r-7b-qlora.yml b/examples/cohere/command-r-7b-qlora.yml new file mode 100644 index 000000000..2ac5c4c09 --- /dev/null +++ b/examples/cohere/command-r-7b-qlora.yml @@ -0,0 +1,71 @@ +base_model: CohereForAI/c4ai-command-r7b-12-2024 +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer + +load_in_8bit: false +load_in_4bit: true +strict: false + +# huggingface repo +chat_template: cohere +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true + +sequence_len: 2048 +sample_packing: true +eval_sample_packing: false +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: true + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index e11a39bd6..f4fbdbed4 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -56,7 +56,7 @@ def do_inference( cfg: Dictionary mapping `axolotl` config keys to values. cli_args: Inference-specific CLI arguments. 
""" - model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True) + model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True) prompter = cli_args.prompter prompter_module = None @@ -151,7 +151,7 @@ def do_inference_gradio( """ import gradio as gr - model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True) + model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True) prompter = cli_args.prompter prompter_module = None diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 595eb3eab..2a3343b6b 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -27,7 +27,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None: """ print_axolotl_text_art() - model, tokenizer = load_model_and_tokenizer(cfg=cfg) + model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg) safe_serialization = cfg.save_safetensors is True LOG.info("Running merge of LoRA with base model...") @@ -44,6 +44,9 @@ def do_merge_lora(*, cfg: DictDefault) -> None: ) tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) + if processor: + processor.save_pretrained(str(Path(cfg.output_dir) / "merged")) + def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: """ diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index cb61fa371..7cc4d2744 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -13,11 +13,16 @@ from typing import Any, Callable, Type, Union, get_args, get_origin import click import requests from pydantic import BaseModel -from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import ( + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, +) from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_tokenizer +from axolotl.utils.models import load_model, load_processor, load_tokenizer configure_logging() LOG = logging.getLogger(__name__) @@ -295,9 +300,13 @@ def load_model_and_tokenizer( *, cfg: DictDefault, inference: bool = False, -) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]: +) -> tuple[ + PreTrainedModel, + PreTrainedTokenizer | PreTrainedTokenizerFast | Any, + ProcessorMixin | None, +]: """ - Helper function for loading a model and tokenizer specified in the given `axolotl` + Helper function for loading a model, tokenizer, and processor specified in the given `axolotl` config. Args: @@ -305,7 +314,7 @@ def load_model_and_tokenizer( inference: Boolean denoting inference mode. Returns: - `transformers` model and tokenizer. + Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin). """ LOG.info(f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) @@ -313,4 +322,9 @@ def load_model_and_tokenizer( LOG.info("loading model...") model, _ = load_model(cfg, tokenizer, inference=inference) - return model, tokenizer + processor = None + if cfg.is_multimodal: + LOG.info("loading processor...") + processor = load_processor(cfg, tokenizer) + + return model, tokenizer, processor diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index b166a3004..7b428eb58 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -1,6 +1,6 @@ # Cut Cross Entropy -Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy operation during loss calculation. +Cut Cross Entropy (CCE) reduces VRAM usage through optimization on the cross-entropy operation during loss calculation. See https://github.com/apple/ml-cross-entropy @@ -29,6 +29,20 @@ plugins: cut_cross_entropy: true ``` +## Supported Models + +- llama +- phi3 +- gemma +- gemma2 +- gemma3 +- gemma3_text +- mistral +- mistral3 +- qwen2 +- cohere +- cohere2 + ## Citation ```bib diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 516e9a2ae..a475cd9f7 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -72,7 +72,9 @@ class CutCrossEntropyPlugin(BasePlugin): if cfg.cut_cross_entropy: self._check_requirements() - from cut_cross_entropy.transformers import cce_patch + from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import ( + cce_patch, + ) with zero_only(): LOG.info( diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py new file mode 100644 index 000000000..5cdc53b0c --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py @@ -0,0 +1,201 @@ +"""Cohere and Cohere2 CCE patch.""" + +# This patch is based off transformers 4.50.0. +# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM. +# It scales the hidden states by the logit scale in advance instead of the logits as the +# operation is done internally and should be mathematically equivalent. 
+ +# pylint: disable=duplicate-code + +from types import MethodType +from typing import Optional, Tuple, Union + +import torch +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, + apply_lce, +) +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.cohere.modeling_cohere import ( + _CONFIG_FOR_DOC, + COHERE_INPUTS_DOCSTRING, + KwargsForCausalLM, +) +from transformers.processing_utils import Unpack +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg + +_PATCH_OPTS: PatchOptions | None = None + + +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") +@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def cce_forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >> from transformers import AutoTokenizer, CohereForCausalLM + + >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01") + >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") + + >> prompt = "Hey, are you conscious? Can you talk to me?" + >> inputs = tokenizer(prompt, return_tensors="pt") + + >> # Generate + >> generate_ids = model.generate(inputs.input_ids, max_length=30) + >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + loss = None + logits = None + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + # scale weight by logit_scale in-place of logits + loss = apply_lce( + hidden_states[:, slice_indices, :], + self.lm_head.weight * self.logit_scale, + labels, + _PATCH_OPTS, + **kwargs, + ) + else: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = logits * self.logit_scale # main diff from Llama + + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def patch_cohere( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.cohere import modeling_cohere + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_cohere.CohereForCausalLM + ), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}." + maybe_model.forward = MethodType(cce_forward, maybe_model) + return maybe_model + + modeling_cohere.CohereForCausalLM.forward = cce_forward + return None + + +def patch_cohere2( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.cohere2 import modeling_cohere2 + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_cohere2.Cohere2ForCausalLM + ), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}." 
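+        # Instance path: bind the patched forward to this object only. The non-instance
+        # path below replaces the class-level forward so models loaded later pick up CCE.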
+ maybe_model.forward = MethodType(cce_forward, maybe_model) + return maybe_model + + modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py new file mode 100644 index 000000000..4c8d2261a --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py @@ -0,0 +1,175 @@ +"""Gemma CCE patch""" + +# This patch is based off transformers 4.50.0. + +# pylint: disable=duplicate-code + +from types import MethodType +from typing import Optional, Tuple, Union + +import torch +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, + apply_lce, +) +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.gemma.modeling_gemma import ( + _CONFIG_FOR_DOC, + GEMMA_INPUTS_DOCSTRING, + KwargsForCausalLM, +) +from transformers.processing_utils import Unpack +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg + +_PATCH_OPTS: PatchOptions | None = None + + +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") +@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def cce_forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") + + >>> prompt = "What is your favorite condiment?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + loss = None + logits = None + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + loss = apply_lce( + hidden_states[:, slice_indices, :], + self.lm_head.weight, + labels, + _PATCH_OPTS, + **kwargs, + ) + else: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def patch_gemma( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.gemma import modeling_gemma + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_gemma.GemmaForCausalLM + ), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}." + maybe_model.forward = MethodType(cce_forward, maybe_model) + return maybe_model + + modeling_gemma.GemmaForCausalLM.forward = cce_forward + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py new file mode 100644 index 000000000..ecbe68085 --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py @@ -0,0 +1,465 @@ +"""Gemma2 and Gemma3 (text and multimodal) CCE patch.""" + +# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29 +# and updated for transformers 4.50.0. +# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works +# with both gemma3 (text and multimodal) models. 
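+# Summary of the deferral below: the ConditionalGeneration forward calls the inner
+# Gemma3ForCausalLM with defer_logits_calculation=True, so the inner forward hands back
+# the sliced hidden states in place of logits; when CCE applies (labels during training),
+# the wrapper runs apply_lce against language_model.lm_head.weight instead of
+# materializing the full vocab-sized logits tensor.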
+ +# pylint: disable=duplicate-code + +from types import MethodType +from typing import Optional, Tuple, Union + +import torch +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, + apply_lce, +) +from torch import nn +from transformers.cache_utils import Cache, HybridCache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.gemma3.modeling_gemma3 import ( + _CONFIG_FOR_DOC, + GEMMA3_INPUTS_DOCSTRING, + Gemma3CausalLMOutputWithPast, + logger, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg + +_PATCH_OPTS: PatchOptions | None = None + + +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") +@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def cce_forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + defer_logits_calculation: bool = False, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + defer_logits_calculation (`bool`, *optional*): + If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the + memory overhead of calculating logits using regular lm_head forward pass and to use CCE. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Gemma3ForCausalLM + + >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" 
+ ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **loss_kwargs, + ) + + hidden_states = outputs[0] + loss = None + logits = None + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + if self.config.final_logit_softcapping is not None: + logger.warning_once( + "final_logit_softcapping is not supported for gemma3_text with CCE. Disabling." + ) + loss = apply_lce( + hidden_states[:, slice_indices, :], + self.lm_head.weight, + labels, + _PATCH_OPTS, + **loss_kwargs, + ) + elif _PATCH_OPTS is not None and defer_logits_calculation: + # defer logits calculation to the ConditionalGeneration forward + logits = hidden_states[:, slice_indices, :] + + if self.config.final_logit_softcapping is not None: + logger.warning_once( + "final_logit_softcapping is not supported for gemma3 with CCE. Disabling." 
+ ) + else: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") +@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def cce_forward_multimodal( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, +) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration + + >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") + >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") + + >>> prompt = "answer en Where is the cow standing?" 
+ >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "answer en Where is the cow standing?\nbeach" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + is_training = token_type_ids is not None and labels is not None + + # Replace image id woth PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_index >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_index + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids # type: ignore + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore + ) + cache_position = torch.arange( # type: ignore + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor( + self.config.image_token_index, + dtype=torch.long, + device=inputs_embeds.device, + ) + ) + else: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze( + -1 + ) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to( + inputs_embeds.device + ) + + if ( + not is_torchdynamo_compiling() + and inputs_embeds[special_image_mask].numel() != image_features.numel() + ): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore + + # mask out pad-token-ids in labels for BC + if labels is not None and self.pad_token_id in labels: + logger.warning_once( + "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. 
" + "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", + ) + labels = torch.where( # type: ignore + input_ids == self.pad_token_id, self.config.ignore_index, labels + ) + + causal_mask = self._update_causal_mask( # pylint: disable=protected-access + attention_mask, + token_type_ids, + past_key_values, + cache_position, + inputs_embeds, + is_training, + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + defer_logits_calculation=True, # enable deferred logits calculation + **lm_kwargs, + ) + + hidden_states = outputs[0] + loss = None + logits = None + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + loss = apply_lce( + hidden_states, + self.language_model.lm_head.weight, + labels, + _PATCH_OPTS, + **lm_kwargs, + ) + else: + logits = hidden_states + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to( + logits.device + ) + shift_logits = shift_logits[ + shift_attention_mask.to(logits.device) != 0 + ].contiguous() + shift_labels = shift_labels[ + shift_attention_mask.to(shift_labels.device) != 0 + ].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Gemma3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +def patch_gemma2( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.gemma2 import modeling_gemma2 + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_gemma2.Gemma2ForCausalLM + ), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}." 
+ maybe_model.forward = MethodType(cce_forward, maybe_model) + return maybe_model + + modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward + return None + + +def patch_gemma3_text( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.gemma3 import modeling_gemma3 + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_gemma3.Gemma3ForCausalLM + ), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}." + maybe_model.forward = MethodType(cce_forward, maybe_model) + return maybe_model + + modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward + return None + + +def patch_gemma3( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.gemma3 import modeling_gemma3 + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration + ), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}." + maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) + + # patch the causal model to enable deferred logits calculation + maybe_model.language_model.forward = MethodType( + cce_forward, maybe_model.language_model + ) + return maybe_model + + modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal + # patch the causal model to enable deferred logits calculation + modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py new file mode 100644 index 000000000..adb65fa8f --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py @@ -0,0 +1,392 @@ +"""Mistral and Mistral3 CCE patch.""" + +# pylint: disable=duplicate-code + +from types import MethodType +from typing import Optional, Tuple, Union + +import torch +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, + apply_lce, +) +from torch import nn +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.mistral3.modeling_mistral3 import ( + Mistral3CausalLMOutputWithPast, +) +from transformers.models.mistral.modeling_mistral import ( + _CONFIG_FOR_DOC, + MISTRAL_INPUTS_DOCSTRING, + KwargsForCausalLM, +) +from transformers.processing_utils import Unpack +from transformers.utils import ( + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg + +_PATCH_OPTS: PatchOptions | None = None + + +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") +@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def cce_forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: Optional[torch.Tensor] | None = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, 
list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + defer_logits_calculation: bool = False, + **kwargs: Unpack[KwargsForCausalLM], +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + defer_logits_calculation (`bool`, *optional*): + If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the + memory overhead of calculating logits using regular lm_head forward pass and to use CCE. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + loss = None + logits = None + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + loss = apply_lce( + hidden_states[:, slice_indices, :], + self.lm_head.weight, + labels, + _PATCH_OPTS, + **kwargs, + ) + elif _PATCH_OPTS is not None and defer_logits_calculation: + # defer logits calculation to the ConditionalGeneration forward + logits = hidden_states[:, slice_indices, :] + else: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def cce_forward_multimodal( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: torch.Tensor | None = None, + **lm_kwargs, +) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration + + >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + + >>> prompt = "[INST][IMG]What is the image?[/INST]" + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is the image?The image depicts two cats lying on a pink blanket." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + vision_feature_layer = ( + vision_feature_layer + if vision_feature_layer is not None + else self.config.vision_feature_layer + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + image_sizes=image_sizes, + ) + + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to( + inputs_embeds.device + ) + if ( + not is_torchdynamo_compiling() + and inputs_embeds[special_image_mask].numel() != image_features.numel() + ): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + defer_logits_calculation=True, # enable deferred 
logits calculation + **lm_kwargs, + ) + + hidden_states = outputs[0] + loss = None + logits = None + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + loss = apply_lce( + hidden_states, + self.language_model.lm_head.weight, + labels, + _PATCH_OPTS, + **lm_kwargs, + ) + else: + logits = hidden_states + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to( + logits.device + ) + shift_logits = logits[..., :-1, :][ + shift_attention_mask.to(logits.device) != 0 + ].contiguous() + shift_labels = labels[..., 1:][ + shift_attention_mask.to(labels.device) != 0 + ].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1).to(shift_logits.device), + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Mistral3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +def patch_mistral( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.mistral import modeling_mistral + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_mistral.MistralForCausalLM + ), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}." + maybe_model.forward = MethodType(cce_forward, maybe_model) + return maybe_model + + modeling_mistral.MistralForCausalLM.forward = cce_forward + return None + + +def patch_mistral3( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.mistral import modeling_mistral + from transformers.models.mistral3 import modeling_mistral3 + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration + ), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}." 
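+        # The wrapper gets cce_forward_multimodal; its inner MistralForCausalLM is
+        # patched next so it hands hidden states back via defer_logits_calculation.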
+ maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) + + # patch the causal model to enable deferred logits calculation + maybe_model.language_model.forward = MethodType( + cce_forward, maybe_model.language_model + ) + return maybe_model + + modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal + # patch the causal model to enable deferred logits calculation + modeling_mistral.MistralForCausalLM.forward = cce_forward + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py new file mode 100644 index 000000000..850764e10 --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py @@ -0,0 +1,379 @@ +"""Mllama CCE patch.""" + +# pylint: disable=duplicate-code + +from types import MethodType +from typing import Optional, Tuple, Union + +import torch +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, + apply_lce, +) +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.mllama.modeling_mllama import ( + MLLAMA_INPUTS_DOCSTRING, + _prepare_cross_attention_mask, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg + +_PATCH_OPTS: PatchOptions | None = None + + +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") +@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig" +) +def cce_forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + defer_logits_calculation: bool = False, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. 
+ This is useful when using packed tensor format (single dimension for batch and sequence length). + + defer_logits_calculation (`bool`, *optional*): + If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the + memory overhead of calculating logits using regular lm_head forward pass and to use CCE. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MllamaForCausalLM + + >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") + >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") + + >>> prompt = "If I had to write a haiku, it would be:" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) + >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(result) + If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. + I love the idea of snowflakes gently falling, each one + ``` + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + loss = None + logits = None + + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + loss = apply_lce( + hidden_states[:, slice_indices, :], + self.lm_head.weight, + labels, + _PATCH_OPTS, + **loss_kwargs, + ) + elif _PATCH_OPTS is not None and defer_logits_calculation: + # defer logits calculation to the ConditionalGeneration forward + logits = hidden_states[:, slice_indices, :] + else: + logits = self.lm_head(hidden_states[:, slice_indices, :]).float() + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") +@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class="MllamaConfig" +) +def cce_forward_multimodal( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: 
Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, MllamaForConditionalGeneration + + >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" + >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint) + >>> processor = AutoProcessor.from_pretrained(checkpoint) + + >>> prompt = "<|image|>If I had to write a haiku for this one" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + + >>> # Generate + >>> output = model.generate(**inputs, max_new_tokens=15) + + >>> prompt_len = inputs.input_ids.shape[-1] + >>> generated_ids = output[:, prompt_len:] + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + >>> print(generated_text) + [', it would be:.\\nA stop sign in Chinatown.\\n'] + ``` + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError( + "`pixel_values` 
and `cross_attention_states` cannot be provided simultaneously" + ) + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError( + "`aspect_ratio_ids` must be provided if `pixel_values` is provided" + ) + # get vision tokens from vision model + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.multi_modal_projector( + cross_attention_states + ).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size # type: ignore + ) + + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = ( + _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + ) + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[ + :, :, cache_position + ] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + defer_logits_calculation=True, # enable deferred logits calculation + **loss_kwargs, + ) + + hidden_states = outputs[0] + loss = None + logits = None + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + loss = apply_lce( + hidden_states, + self.language_model.lm_head.weight, + labels, + _PATCH_OPTS, + **loss_kwargs, + ) + else: + # Temporary fix to calculate the loss in main class, as the model's vocab size may be resized + logits = hidden_states + + if labels is not None: + loss = self.loss_function( + logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs + ) + + if not return_dict: + return (loss,) + outputs if loss is not None else outputs + + return CausalLMOutputWithPast( + loss=loss, + logits=outputs.logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def patch_mllama( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.mllama import modeling_mllama + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_mllama.MllamaForConditionalGeneration + ), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}." 
+ maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) + + # patch the language model + maybe_model.language_model.forward = MethodType( + cce_forward, maybe_model.language_model + ) + return maybe_model + + modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal + + # patch the causal language model + modeling_mllama.MllamaForCausalLM.forward = cce_forward + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py new file mode 100644 index 000000000..b9c83ff02 --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py @@ -0,0 +1,85 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. + +"""Cut Cross Entropy patcher""" + +import transformers +from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl +from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT +from cut_cross_entropy.transformers.llama import patch_llama +from cut_cross_entropy.transformers.phi3 import patch_phi3 +from cut_cross_entropy.transformers.qwen2 import patch_qwen2 +from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT + +from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import ( + patch_cohere, + patch_cohere2, +) +from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma +from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import ( + patch_gemma2, + patch_gemma3, + patch_gemma3_text, +) +from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import ( + patch_mistral, + patch_mistral3, +) +from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama + +CUT_CROSS_ENTROPY_MODEL_MAPPING = { + "llama": patch_llama, + "mllama": patch_mllama, + "phi3": patch_phi3, + "gemma": patch_gemma, + "gemma2": patch_gemma2, + "gemma3": patch_gemma3, + "gemma3_text": patch_gemma3_text, + "mistral": patch_mistral, + "mistral3": patch_mistral3, + "qwen2": patch_qwen2, + "cohere": patch_cohere, + "cohere2": patch_cohere2, +} + + +def cce_patch( + model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig, + impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT, + reduction: str = "mean", + filter_eps: float | str | None = "auto", + accum_e_fp32: bool = False, + accum_c_fp32: bool = False, + filter_e_grad: bool = True, + filter_c_grad: bool = True, + train_only: bool = False, +) -> TransformersModelT | None: + if isinstance(impl, LinearCrossEntropyImpl): + impl = impl.name.lower() + + if impl not in (v.name.lower() for v in LinearCrossEntropyImpl): + raise ValueError(f"Unknown {impl=}") + + if isinstance(model_type_or_model, transformers.PreTrainedModel): + model_type = model_type_or_model.config.model_type + elif isinstance(model_type_or_model, transformers.PretrainedConfig): + model_type = model_type_or_model.model_type + else: + model_type = model_type_or_model + + patch_options = PatchOptions( + impl=impl, + reduction=reduction, + filter_eps=filter_eps, + accum_e_fp32=accum_e_fp32, + accum_c_fp32=accum_c_fp32, + filter_e_grad=filter_e_grad, + filter_c_grad=filter_c_grad, + train_only=train_only, + ) + + if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING: + return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type]( + model_type_or_model, patch_options + ) + + raise RuntimeError(f"Unknown model type {model_type}") diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index d6d209db5..cd819fba4 100644 --- 
a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -23,6 +23,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "gemma", "gemma2", "gemma3_text", + "cohere", + "cohere2", "gemmoe", "starcoder2", "deepseek_v2", diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 9ccd2ca0c..7dbdd0b76 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -314,6 +314,7 @@ def save_initial_configs( tokenizer: PreTrainedTokenizer, model: PreTrainedModel, peft_config: PeftConfig | None, + processor: ProcessorMixin | None, ): """ Save initial configurations before training. @@ -341,6 +342,10 @@ def save_initial_configs( LOG.info(f"Pre-saving model config to {cfg.output_dir}...") model.config.save_pretrained(str(output_dir)) + if processor: + LOG.info(f"Pre-saving processor to {cfg.output_dir}...") + processor.save_pretrained(str(output_dir)) + def setup_model_card(cfg: DictDefault): """ @@ -408,6 +413,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> PeftModel | PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, + ProcessorMixin | None, ]: """ Load model, tokenizer, trainer, etc. Helper function to encapsulate the full @@ -423,6 +429,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> - Model - Tokenizer - PEFT config + - Processor """ # Load tokenizer, processor and model model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg) @@ -453,6 +460,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> model, tokenizer, peft_config, + processor, ) @@ -475,6 +483,7 @@ def train( model, tokenizer, peft_config, + processor, ) = setup_model_and_trainer(cfg, dataset_meta) # Determine if we need to resume from a checkpoint @@ -490,7 +499,7 @@ def train( ) # Save initial configs - save_initial_configs(cfg, tokenizer, model, peft_config) + save_initial_configs(cfg, tokenizer, model, peft_config, processor) # Set up signal handler for graceful termination setup_signal_handler(cfg, model, safe_serialization) diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py index ce7a2bf0f..69c84f445 100644 --- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py +++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py @@ -408,7 +408,7 @@ def test_kernel_training_integration(): ) # Load model - model, _ = load_model_and_tokenizer(cfg=cfg) + model, _, _ = load_model_and_tokenizer(cfg=cfg) # Verify correct activation function layer = model.model.model.layers[0]
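As a reading aid for the hunks above, the sketch below shows one plausible way the new `cce_patch` entry point from `patch.py` could be invoked ahead of model loading, and what the patched `forward` methods then do at train time. The checkpoint name, dtype, and keyword arguments are illustrative assumptions, not part of this change; in axolotl this call is normally made by the Cut Cross Entropy integration rather than by hand.

```python
# Illustrative sketch only: applying the Cut Cross Entropy patch before the
# model is instantiated, so the patched `forward` methods defined above are
# picked up. Checkpoint name and kwargs are assumptions for the example.
from transformers import AutoConfig, MllamaForConditionalGeneration

from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import cce_patch

checkpoint = "meta-llama/Llama-3.2-11B-Vision"
config = AutoConfig.from_pretrained(checkpoint)

# Class-level patch: `cce_patch` looks up `config.model_type` ("mllama") in
# CUT_CROSS_ENTROPY_MODEL_MAPPING and replaces the `forward` methods of
# MllamaForCausalLM and MllamaForConditionalGeneration with the CCE variants.
cce_patch(config, impl="cce", train_only=True)

model = MllamaForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype="bfloat16",
)

# At train time, the patched multimodal forward calls the language model with
# defer_logits_calculation=True, receives hidden states instead of full logits,
# and computes the loss via apply_lce against language_model.lm_head.weight,
# avoiding materializing the full [batch, seq_len, vocab_size] logits tensor.
```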