diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 195746d2d..a029ba39f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: hooks: - id: isort - repo: https://github.com/PyCQA/flake8 - rev: 7.2.0 + rev: 7.3.0 hooks: - id: flake8 - repo: https://github.com/pylint-dev/pylint @@ -27,7 +27,7 @@ repos: hooks: - id: pylint - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.16.0 + rev: v1.16.1 hooks: - id: mypy additional_dependencies: @@ -36,7 +36,7 @@ repos: 'pydantic>=2.5.3', ] - repo: https://github.com/PyCQA/bandit - rev: 1.8.3 + rev: 1.8.5 hooks: - id: bandit args: [ diff --git a/README.md b/README.md index 3bfce8df1..e18220567 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ Features: - **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models. - **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM). - **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference. -- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more! +- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more! - **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets. - **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware. diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index d1fca9441..90cc178fc 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -9,7 +9,7 @@ order: 3 Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2. ```{.json filename="data.jsonl"} -{"conversations": [{"role": "...", "content": "..."}]} +{"messages": [{"role": "...", "content": "..."}, {"role": "...", "content": "..."}, ...]} ``` See [configs](../config-reference.qmd) for full configs and supported templates. 
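For reference, the `messages` format documented above is consumed via a `chat_template` dataset entry in the training YAML. A minimal sketch, assuming a local `data.jsonl` in that format — the path and template name here are placeholders, and `message_property_mappings` is only needed when the message keys differ from `role`/`content` (as in the SlimOrca-based example configs later in this diff):

```yaml
chat_template: tokenizer_default  # or a named template such as falcon_h1 / gemma3
datasets:
  - path: ./data.jsonl            # placeholder path; a HuggingFace dataset name also works
    type: chat_template
    field_messages: messages      # key holding the list of {role, content} messages
    # message_property_mappings:  # uncomment and adjust if keys differ from role/content
    #   role: from
    #   content: value
```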
diff --git a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml new file mode 100644 index 000000000..1dd901154 --- /dev/null +++ b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml @@ -0,0 +1,71 @@ +base_model: tiiuae/Falcon-H1-1.5B-Deep-Base +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/falcon-h1/falcon-h1-1b-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-qlora.yaml new file mode 100644 index 000000000..24dc7cae3 --- /dev/null +++ b/examples/falcon-h1/falcon-h1-1b-qlora.yaml @@ -0,0 +1,71 @@ +base_model: tiiuae/Falcon-H1-1.5B-Base +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/falcon-h1/falcon-h1-34b-qlora.yaml b/examples/falcon-h1/falcon-h1-34b-qlora.yaml new file mode 100644 index 000000000..43eb1967b --- /dev/null +++ b/examples/falcon-h1/falcon-h1-34b-qlora.yaml @@ -0,0 +1,71 @@ +base_model: tiiuae/Falcon-H1-34B-Base +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# 
Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/falcon-h1/falcon-h1-3b-qlora.yaml b/examples/falcon-h1/falcon-h1-3b-qlora.yaml new file mode 100644 index 000000000..00929bbf0 --- /dev/null +++ b/examples/falcon-h1/falcon-h1-3b-qlora.yaml @@ -0,0 +1,71 @@ +base_model: tiiuae/Falcon-H1-3B-Base +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/falcon-h1/falcon-h1-500m-qlora.yaml b/examples/falcon-h1/falcon-h1-500m-qlora.yaml new file mode 100644 index 000000000..e2640de7b --- /dev/null +++ b/examples/falcon-h1/falcon-h1-500m-qlora.yaml @@ -0,0 +1,71 @@ +base_model: tiiuae/Falcon-H1-0.5B-Instruct +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + 
+adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/falcon-h1/falcon-h1-7b-qlora.yaml b/examples/falcon-h1/falcon-h1-7b-qlora.yaml new file mode 100644 index 000000000..183e423b5 --- /dev/null +++ b/examples/falcon-h1/falcon-h1-7b-qlora.yaml @@ -0,0 +1,71 @@ +base_model: tiiuae/Falcon-H1-7B-Base +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml index 44310558c..217c887aa 100644 --- a/examples/gemma3/gemma-3-1b-qlora.yml +++ b/examples/gemma3/gemma-3-1b-qlora.yml @@ -13,6 +13,8 @@ load_in_4bit: true # huggingface repo chat_template: gemma3 +eot_tokens: + - <end_of_turn> datasets: - path: cgato/SlimOrcaDedupCleaned type: chat_template diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml index 0d89d9ffb..d78559ae3 100644 --- a/examples/gemma3/gemma-3-4b-qlora.yml +++ b/examples/gemma3/gemma-3-4b-qlora.yml @@ -6,6 +6,8 @@ load_in_4bit: true ddp_find_unused_parameters: true chat_template: gemma3 +eot_tokens: + - <end_of_turn> datasets: - path: cgato/SlimOrcaDedupCleaned type: chat_template diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml index 339df92e5..183eb88e8 100644 --- a/examples/gemma3/gemma-3-4b-vision-qlora.yml +++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml @@ -12,6 +12,8 @@ sample_packing: false ddp_find_unused_parameters: true chat_template: gemma3 +eot_tokens: + - <end_of_turn> datasets: - path: HuggingFaceH4/llava-instruct-mix-vsft type: chat_template diff --git
a/examples/qwen2_5-vl/lora-7b.yaml b/examples/qwen2_5-vl/lora-7b.yaml new file mode 100644 index 000000000..25d02805f --- /dev/null +++ b/examples/qwen2_5-vl/lora-7b.yaml @@ -0,0 +1,55 @@ +base_model: Qwen/Qwen2.5-VL-7B-Instruct +processor_type: AutoProcessor + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: qwen2_vl +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: true +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index 4a92746c1..bb9224bb0 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else "" print( UNINSTALL_PREFIX - + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"' + + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154"' ) diff --git a/src/axolotl/cli/checks.py b/src/axolotl/cli/checks.py index 10086c2a4..a743e74dc 100644 --- a/src/axolotl/cli/checks.py +++ b/src/axolotl/cli/checks.py @@ -6,6 +6,7 @@ from pathlib import Path from accelerate.commands.config import config_args from huggingface_hub import HfApi from huggingface_hub.utils import LocalTokenNotFoundError +from requests import HTTPError from axolotl.utils.logging import get_logger @@ -46,3 +47,8 @@ def check_user_token() -> bool: "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets." ) return False + except HTTPError: + LOG.warning( + "Error accessing HuggingFace. This may be due to a network issue or rate limiting." 
+ ) + return False diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index fbae253d6..b0e6e8eae 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -20,7 +20,7 @@ from torch.utils.data import ( SequentialSampler, ) from transformers import Trainer -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker from trl.trainer.utils import pad_to_length from typing_extensions import override @@ -116,14 +116,15 @@ class AxolotlTrainer( sequential=self.args.sample_packing_sequentially, drop_last=True, num_processes=self.args.dataset_num_proc, + mp_start_method=self.args.sample_packing_mp_start_method or "fork", ) len(sampler) return sampler def _get_train_sampler( - self, train_dataset: Optional[Dataset] = None - ) -> Optional[Sampler]: + self, train_dataset: Dataset | None = None + ) -> Sampler | None: """ Helper method to get the sampler for training. Handles cases for sample packing and curriculum sampling (sequential). @@ -132,16 +133,22 @@ class AxolotlTrainer( If the dataset is non-empty, a sampler is returned, the type of which depends on the passed training args. """ + # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24 + if train_dataset is None: + train_dataset = self.train_dataset + if train_dataset is None or not has_length(train_dataset): + return None + use_sample_packing = self.args.sample_packing and not self.args.pretraining # Determine the base sampler first if self.args.curriculum_sampling: - base_sampler = SequentialSampler(self.train_dataset) + base_sampler = SequentialSampler(train_dataset) elif use_sample_packing: - base_sampler = RandomSampler(self.train_dataset) + base_sampler = RandomSampler(train_dataset) else: # Default to parent class implementation for standard random sampling - return super()._get_train_sampler() + return super()._get_train_sampler(train_dataset) # Apply multipack wrapper if needed if use_sample_packing: @@ -160,6 +167,10 @@ class AxolotlTrainer( If the dataset is non-empty, a sampler is returned, the type of which depends on the passed training args. """ + # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24 + if eval_dataset is None or not has_length(eval_dataset): + return None + # Multipacking enabled if training is enabled and eval is not explicitly disabled use_multipack = ( self.args.sample_packing and self.args.eval_sample_packing is not False diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py index 8fcaff632..e04be43e0 100644 --- a/src/axolotl/core/training_args_base.py +++ b/src/axolotl/core/training_args_base.py @@ -38,6 +38,10 @@ class AxolotlTrainingMixins: "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." 
}, ) + sample_packing_mp_start_method: str | None = field( + default=None, + metadata={"help": "The multiprocessing start method to use."}, + ) multipack_real_batches: bool = field( default=False, metadata={"help": "Use real batches for efficient training."}, diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index bddf3ced2..b5e3ecda8 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,19 +19,11 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154" ``` ## Usage -**NOTE**: If you are training a VLM model, please use older version of Axolotl as upstream has applied a major VLM refactor, and our patches have not been updated yet. - -```bash -git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764 - -pip3 install --no-build-isolation -e . -``` - ```yaml plugins: - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin @@ -39,27 +31,29 @@ plugins: ## Supported Models -- llama -- llama4 -- llama4_text -- mllama -- phi3 +- cohere +- cohere2 - gemma - gemma2 - gemma3 - gemma3_text +- glm +- glm4 +- llama +- llama4 +- llama4_text - mistral - mistral3 +- mllama +- phi +- phi3 +- phi4_multimodal - qwen2 -- qwen2_moe - qwen2_vl +- qwen2_moe - qwen2_5_vl - qwen3 - qwen3_moe -- cohere -- cohere2 -- glm -- glm4 ## Citation diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index c29bb55d4..37f4dba68 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -31,8 +31,8 @@ from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( - "Please install cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`' + "Please install Axolotl's fork of cut_cross_entropy with transformers support using " + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@7f6afce"`' ) @@ -64,16 +64,28 @@ class CutCrossEntropyPlugin(BasePlugin): "cut_cross_entropy.transformers" ) if cce_spec_transformers is None: - raise ImportError(_CCE_INSTALL_MESSAGE) + raise ImportError( + "Transformers support is not installed. " + _CCE_INSTALL_MESSAGE + ) + + # Check if Axolotl's cce fork is installed + try: + from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK + + if not AXOLOTL_CCE_FORK: + raise ImportError + except ImportError as e: + raise ImportError( + "Axolotl's fork of cut_cross_entropy is not installed. 
" + + _CCE_INSTALL_MESSAGE + ) from e def pre_model_load(self, cfg): """Apply cut cross entropy before model loading if enabled.""" if cfg.cut_cross_entropy: self._check_requirements() - from .monkeypatch.patch import ( - cce_patch, - ) + from cut_cross_entropy.transformers.patch import cce_patch LOG.info( f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}" diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/__init__.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py deleted file mode 100644 index ea9e10724..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py +++ /dev/null @@ -1,191 +0,0 @@ -"""Cohere and Cohere2 CCE patch.""" - -# This patch is based off transformers 4.50.0. -# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM. -# It scales the hidden states by the logit scale in advance instead of the logits as the -# operation is done internally and should be mathematically equivalent. - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.cohere.modeling_cohere import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. 
- This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >> from transformers import AutoTokenizer, CohereForCausalLM - - >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01") - >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") - - >> prompt = "Hey, are you conscious? Can you talk to me?" - >> inputs = tokenizer(prompt, return_tensors="pt") - - >> # Generate - >> generate_ids = model.generate(inputs.input_ids, max_length=30) - >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - # scale hidden_states by logit_scale in-place of logits - loss = apply_lce( - hidden_states[:, slice_indices, :] * self.logit_scale, - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - logits = logits * self.logit_scale # main diff from Llama - - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_cohere( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.cohere import modeling_cohere - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_cohere.CohereForCausalLM - ), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}." 
- maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_cohere.CohereForCausalLM.forward = cce_forward - return None - - -def patch_cohere2( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.cohere2 import modeling_cohere2 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_cohere2.Cohere2ForCausalLM - ), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py deleted file mode 100644 index ae3d8c6ef..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Gemma CCE patch""" - -# This patch is based off transformers 4.50.0. - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.gemma.modeling_gemma import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). 
- - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, GemmaForCausalLM - - >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" - ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_gemma( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma import modeling_gemma - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma.GemmaForCausalLM - ), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_gemma.GemmaForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py deleted file mode 100644 index 644e5cce7..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py +++ /dev/null @@ -1,447 +0,0 @@ -"""Gemma2 and Gemma3 (text and multimodal) CCE patch.""" - -# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29 -# and updated for transformers 4.50.0. 
-# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works -# with both gemma3 (text and multimodal) models. - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) -from torch import nn -from transformers.cache_utils import Cache, HybridCache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.gemma3.modeling_gemma3 import ( - Gemma3CausalLMOutputWithPast, - logger, -) -from transformers.utils import ( - is_torchdynamo_compiling, -) -from transformers.utils.deprecation import deprecate_kwarg - -from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*): - If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Gemma3ForCausalLM - - >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" 
- ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **loss_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - softcap=getattr(self.config, "final_logit_softcapping", None), - **loss_kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if self.config.final_logit_softcapping is not None: - logits = logits / self.config.final_logit_softcapping - logits = torch.tanh(logits) - logits = logits * self.config.final_logit_softcapping - - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward_multimodal( - self, - input_ids: torch.LongTensor | None = None, - pixel_values: torch.FloatTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, - token_type_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, -) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. 
- - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration - - >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") - >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") - - >>> prompt = "answer en Where is the cow standing?" - >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "answer en Where is the cow standing?\nbeach" - ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - is_training = token_type_ids is not None and labels is not None - - # Replace image id woth PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_index >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_index - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - else: - llm_input_ids = input_ids # type: ignore - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore - ) - cache_position = torch.arange( # type: ignore - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - # Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor( - self.config.image_token_index, - dtype=torch.long, - device=inputs_embeds.device, - ) - ) - else: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze( - -1 - ) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to( - inputs_embeds.device - ) - - if ( - not is_torchdynamo_compiling() - and inputs_embeds[special_image_mask].numel() != image_features.numel() - ): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input 
text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore - - # mask out pad-token-ids in labels for BC - if labels is not None and self.pad_token_id in labels: - logger.warning_once( - "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " - "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", - ) - labels = torch.where( # type: ignore - input_ids == self.pad_token_id, self.config.ignore_index, labels - ) - - causal_mask = self._update_causal_mask( # pylint: disable=protected-access - attention_mask, - token_type_ids, - past_key_values, - cache_position, - inputs_embeds, - is_training, - ) - outputs = self.language_model( - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **lm_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - softcap=getattr(self.config, "final_logit_softcapping", None), - **lm_kwargs, - ) - else: - logits = hidden_states - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to( - logits.device - ) - shift_logits = shift_logits[ - shift_attention_mask.to(logits.device) != 0 - ].contiguous() - shift_labels = shift_labels[ - shift_attention_mask.to(shift_labels.device) != 0 - ].contiguous() - else: - shift_logits = shift_logits.contiguous() - shift_labels = shift_labels.contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - - flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) - flat_labels = shift_labels.view(-1).to(shift_logits.device) - loss = loss_fct(flat_logits, flat_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Gemma3CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -def patch_gemma2( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma2 import modeling_gemma2 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma2.Gemma2ForCausalLM - ), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward - return None - - -def patch_gemma3_text( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma3 import modeling_gemma3 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma3.Gemma3ForCausalLM - ), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward - return None - - -def patch_gemma3( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma3 import modeling_gemma3 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration - ), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}." 
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the causal model to enable deferred logits calculation - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal - # patch the causal model to enable deferred logits calculation - modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py deleted file mode 100644 index 3df909f88..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py +++ /dev/null @@ -1,57 +0,0 @@ -"""GLM 4 patch. GLM family inherits from Llama.""" - -from types import MethodType - -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) - - -def patch_glm( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - # Set the _PATCH_OPTS in the llama patch file - import cut_cross_entropy.transformers.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from cut_cross_entropy.transformers.llama import cce_forward - from transformers.models.glm import modeling_glm - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_glm.GlmForCausalLM - ), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_glm.GlmForCausalLM.forward = cce_forward - return None - - -def patch_glm4( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - # Set the _PATCH_OPTS in the llama patch file - import cut_cross_entropy.transformers.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from cut_cross_entropy.transformers.llama import cce_forward - from transformers.models.glm4 import modeling_glm4 - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_glm4.Glm4ForCausalLM - ), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_glm4.Glm4ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py deleted file mode 100644 index bed411ace..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Llama CCE patch. 
Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - - -from types import MethodType -from typing import Optional, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.models.llama.modeling_llama import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import can_return_tuple - -_PATCH_OPTS: PatchOptions | None = None - - -@can_return_tuple -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> CausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - if hidden_states is None: - raise ValueError("hidden_states is None") - - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_llama( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - """Patch Llama for CCE.""" - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.llama import modeling_llama - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_llama.LlamaForCausalLM - ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_llama.LlamaForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py deleted file mode 100644 index 3143e9c8d..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py +++ /dev/null @@ -1,401 +0,0 @@ -"""Llama4 CCE patch. 
Adapted from transformers 4.51.0.""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch import nn -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.llama4.modeling_llama4 import ( - Llama4CausalLMOutputWithPast, -) - -_PATCH_OPTS: PatchOptions | None = None - - -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*, defaults to `False`): - If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Llama4ForCausalLM - - >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def cce_forward_multimodal( - self, - input_ids: torch.LongTensor | None = None, # type: ignore - pixel_values: torch.FloatTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[Union[int, list[int]]] = None, - vision_feature_select_strategy: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: torch.Tensor | None = None, - **lm_kwargs, -) -> Union[Tuple, Llama4CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, LlavaForConditionalGeneration - - >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") - >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") - - >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed" - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_config.vision_feature_layer - ) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_config.vision_feature_select_strategy - ) - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) # type: ignore - - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - image_sizes=image_sizes, - ) - original_inputs_embeds_shape = inputs_embeds.shape # type: ignore - - vision_flat = image_features.view(-1, image_features.size(-1)) - projected_vision_flat = self.multi_modal_projector(vision_flat) - - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - final_mask = special_image_mask.to(inputs_embeds.device) # type: ignore - inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore - - final_mask_1d = final_mask[..., 0].reshape(-1) - num_tokens_to_fill = final_mask_1d.sum() - - if num_tokens_to_fill != projected_vision_flat.size(0): - raise ValueError( - f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, " - f"but multi_modal_projector returned {projected_vision_flat.size(0)}" - ) - - expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1)) - inputs_embeds = 
inputs_embeds.masked_scatter( - expanded_mask, projected_vision_flat - ) # type: ignore - inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) # type: ignore - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **lm_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - # TODO: check if need to handle attention_mask - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - **lm_kwargs, - ) - else: - logits = hidden_states - if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to( - logits.device - ) - shift_logits = logits[..., :-1, :][ - shift_attention_mask.to(logits.device) != 0 - ].contiguous() - shift_labels = labels[..., 1:][ - shift_attention_mask.to(labels.device) != 0 - ].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1).to(shift_logits.device), - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Llama4CausalLMOutputWithPast( - loss=loss, - logits=logits, # type: ignore # TODO: check if need to create dummy logits - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -def patch_llama4_text( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.llama4 import modeling_llama4 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_llama4.Llama4ForCausalLM - ), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - - return maybe_model - - setattr( - modeling_llama4.Llama4ForCausalLM, - "forward", - cce_forward, - ) - return None - - -def patch_llama4( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.llama4 import modeling_llama4 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_llama4.Llama4ForConditionalGeneration - ), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}." 
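- # MethodType creates a bound method, so assigning it to maybe_model.forward patches only this model instance and leaves the class-level forward untouched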
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the language model - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - setattr( - modeling_llama4.Llama4ForConditionalGeneration, - "forward", - cce_forward_multimodal, - ) - - # patch the causal language model - setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward) - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py deleted file mode 100644 index aa252701e..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py +++ /dev/null @@ -1,384 +0,0 @@ -"""Mistral and Mistral3 CCE patch.""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch import nn -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.mistral3.modeling_mistral3 import ( - Mistral3CausalLMOutputWithPast, -) -from transformers.models.mistral.modeling_mistral import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils import ( - is_torchdynamo_compiling, -) -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] | None = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **kwargs: Unpack[KwargsForCausalLM], -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*): - If `True`, defer logits calculation to the ConditionalGeneration forward. 
This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MistralForCausalLM - - >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def cce_forward_multimodal( - self, - input_ids: torch.LongTensor | None = None, - pixel_values: torch.FloatTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[Union[int, list[int]]] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: torch.Tensor | None = None, - 
**lm_kwargs, -) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration - - >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") - >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") - - >>> prompt = "[INST][IMG]What is the image?[/INST]" - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is the image?The image depicts two cats lying on a pink blanket." 
- ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_feature_layer - ) - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - image_sizes=image_sizes, - ) - - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to( - inputs_embeds.device - ) - if ( - not is_torchdynamo_compiling() - and inputs_embeds[special_image_mask].numel() != image_features.numel() - ): - n_image_tokens = (input_ids == self.config.image_token_index).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **lm_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - **lm_kwargs, - ) - else: - logits = hidden_states - if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to( - logits.device - ) - shift_logits = logits[..., :-1, :][ - shift_attention_mask.to(logits.device) != 0 - ].contiguous() - shift_labels = labels[..., 1:][ - shift_attention_mask.to(labels.device) != 0 - ].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1).to(shift_logits.device), - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Mistral3CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -def patch_mistral( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.mistral import modeling_mistral - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_mistral.MistralForCausalLM - ), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_mistral.MistralForCausalLM.forward = cce_forward - return None - - -def patch_mistral3( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.mistral import modeling_mistral - from transformers.models.mistral3 import modeling_mistral3 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration - ), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}." 
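- # instance-level patch: the wrapper computes the LCE loss from hidden states, which the inner language model (patched below) returns with logits deferred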
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the causal model to enable deferred logits calculation - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal - # patch the causal model to enable deferred logits calculation - modeling_mistral.MistralForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py deleted file mode 100644 index e82853e6c..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py +++ /dev/null @@ -1,366 +0,0 @@ -"""Mllama CCE patch.""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.mllama.modeling_mllama import ( - _prepare_cross_attention_mask, -) -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*): - If `True`, defer logits calculation to the ConditionalGeneration forward. 
This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MllamaForCausalLM - - >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") - >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") - - >>> prompt = "If I had to write a haiku, it would be:" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) - >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - >>> print(result) - If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. - I love the idea of snowflakes gently falling, each one - ``` - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - cross_attention_states=cross_attention_states, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **loss_kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]).float() - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward_multimodal( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: 
Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, MllamaForConditionalGeneration - - >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" - >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint) - >>> processor = AutoProcessor.from_pretrained(checkpoint) - - >>> prompt = "<|image|>If I had to write a haiku for this one" - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(text=prompt, images=image, return_tensors="pt") - - >>> # Generate - >>> output = model.generate(**inputs, max_new_tokens=15) - - >>> prompt_len = inputs.input_ids.shape[-1] - >>> generated_ids = output[:, prompt_len:] - >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - >>> print(generated_text) - [', it would be:.\\nA stop sign in Chinatown.\\n'] - ``` - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and cross_attention_states is not None: - raise ValueError( - "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" - ) - - if pixel_values is not None: - if aspect_ratio_ids is None: - raise ValueError( - "`aspect_ratio_ids` must be provided if `pixel_values` is provided" - ) - # get vision tokens from vision model - vision_outputs = self.vision_model( - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - 
aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - ) - cross_attention_states = vision_outputs[0] - cross_attention_states = self.multi_modal_projector( - cross_attention_states - ).reshape( - -1, cross_attention_states.shape[-2], self.hidden_size # type: ignore - ) - - if cross_attention_mask is not None: - cross_attention_mask, full_text_row_masked_out_mask = ( - _prepare_cross_attention_mask( - cross_attention_mask, - num_vision_tokens=self.vision_model.num_patches, - dtype=self.dtype, - ) - ) - else: - full_text_row_masked_out_mask = None - - if cross_attention_mask is not None and cache_position is not None: - cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[ - :, :, cache_position - ] - - outputs = self.language_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - use_cache=use_cache, - inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **loss_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - **loss_kwargs, - ) - else: - # Temporary fix to calculate the loss in main class, as the model's vocab size may be resized - logits = hidden_states - - if labels is not None: - loss = self.loss_function( - logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs - ) - - if not return_dict: - return (loss,) + outputs if loss is not None else outputs - - return CausalLMOutputWithPast( - loss=loss, - logits=outputs.logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_mllama( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.mllama import modeling_mllama - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_mllama.MllamaForConditionalGeneration - ), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}." 
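- # instance-level patch: cce_forward_multimodal applies the LCE loss against the language model's lm_head weights, so the language model below is patched to defer its logits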
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the language model - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal - - # patch the causal language model - modeling_mllama.MllamaForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py deleted file mode 100644 index 8176a1f0c..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. - -"""Cut Cross Entropy patcher""" - -import transformers -from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl -from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT -from cut_cross_entropy.transformers.phi3 import patch_phi3 -from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT - -from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import ( - patch_cohere, - patch_cohere2, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma -from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import ( - patch_gemma2, - patch_gemma3, - patch_gemma3_text, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import ( - patch_glm, - patch_glm4, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import ( - patch_llama, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import ( - patch_llama4, - patch_llama4_text, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import ( - patch_mistral, - patch_mistral3, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import ( - patch_qwen2, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import ( - patch_qwen2_5_vl, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import ( - patch_qwen2_moe, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import ( - patch_qwen2_vl, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3 -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import ( - patch_qwen3_moe, -) - -CUT_CROSS_ENTROPY_MODEL_MAPPING = { - "llama": patch_llama, - "llama4": patch_llama4, - "llama4_text": patch_llama4_text, - "mllama": patch_mllama, - "phi3": patch_phi3, - "gemma": patch_gemma, - "gemma2": patch_gemma2, - "gemma3": patch_gemma3, - "gemma3_text": patch_gemma3_text, - "mistral": patch_mistral, - "mistral3": patch_mistral3, - "qwen2": patch_qwen2, - "qwen2_moe": patch_qwen2_moe, - "qwen2_vl": patch_qwen2_vl, - "qwen2_5_vl": patch_qwen2_5_vl, - "qwen3": patch_qwen3, - "qwen3_moe": patch_qwen3_moe, - "cohere": patch_cohere, - "cohere2": patch_cohere2, - "glm": patch_glm, - "glm4": patch_glm4, -} - - -def cce_patch( - model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig, - impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT, - reduction: str = "mean", - filter_eps: float | str | None = "auto", - accum_e_fp32: bool = False, - accum_c_fp32: bool = False, - filter_e_grad: bool = True, - filter_c_grad: bool = True, - train_only: bool = False, -) -> TransformersModelT | None: - if isinstance(impl, 
LinearCrossEntropyImpl): - impl = impl.name.lower() - - if impl not in (v.name.lower() for v in LinearCrossEntropyImpl): - raise ValueError(f"Unknown {impl=}") - - if isinstance(model_type_or_model, transformers.PreTrainedModel): - if hasattr(model_type_or_model, "config"): - model_type = getattr( - getattr(model_type_or_model, "config", None), "model_type", None - ) - else: - raise ValueError( - "model_type_or_model is a PreTrainedModel but does not have a config attribute" - ) - elif isinstance(model_type_or_model, transformers.PretrainedConfig): - model_type = model_type_or_model.model_type - else: - model_type = model_type_or_model - - patch_options = PatchOptions( - impl=impl, - reduction=reduction, - filter_eps=filter_eps, - accum_e_fp32=accum_e_fp32, - accum_c_fp32=accum_c_fp32, - filter_e_grad=filter_e_grad, - filter_c_grad=filter_c_grad, - train_only=train_only, - ) - - if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING: - return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type]( - model_type_or_model, patch_options - ) - - raise RuntimeError(f"Unknown model type {model_type}") diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py deleted file mode 100644 index 3f6d2b3e9..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method.""" - -# pylint: disable=duplicate-code - -from types import MethodType - -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) - - -def patch_qwen2( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - from transformers.models.qwen2 import modeling_qwen2 - - # Set the _PATCH_OPTS in the llama patch file - import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import ( - cce_forward, - ) - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2.Qwen2ForCausalLM - ), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py deleted file mode 100644 index 16206006f..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Qwen2.5 VL CCE patch. 
Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch.nn import CrossEntropyLoss -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLCausalLMOutputWithPast, -) - -_PATCH_OPTS: PatchOptions | None = None - - -def cce_forward_multimodal( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, -) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration - - >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - - >>> messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What is shown in this image?"}, - ], - }, - ] - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - >>> inputs = processor(text=[text], images=[image], return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
- ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) - - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - - mask = input_ids == self.config.video_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - video_mask = mask_expanded.to(inputs_embeds.device) - - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - - # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if ( - (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore - ): - position_ids, rope_deltas = self.get_rope_index( - input_ids, - image_grid_thw, - video_grid_thw, - second_per_grid_ts, - attention_mask, - ) - self.rope_deltas = rope_deltas - # then use the prev pre-calculated rope-deltas to get the correct position ids - else: - batch_size, seq_length, _ = inputs_embeds.shape - delta = ( - (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) - if cache_position is not None - else 0 - ) - position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore - position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore - position_ids = position_ids.add(delta) # type: ignore - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore - - outputs = self.model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - logits = None - loss = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.lm_head.weight, - labels, - _PATCH_OPTS, - ) - else: - logits = self.lm_head(hidden_states) - - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Qwen2_5_VLCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=self.rope_deltas, - ) - - -def patch_qwen2_5_vl( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration - ), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}." 
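- # only the ConditionalGeneration forward is replaced here: it applies the LCE loss directly against lm_head.weight, so no separate language-model patch is needed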
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - return maybe_model - - modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = ( - cce_forward_multimodal - ) - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py deleted file mode 100644 index afe56266e..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.models.qwen2_moe.modeling_qwen2_moe import ( - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, - load_balancing_loss_func, -) -from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import can_return_tuple - -_PATCH_OPTS: PatchOptions | None = None - - -@can_return_tuple -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, -) -> MoeCausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM - - >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: MoeModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - cache_position=cache_position, - ) - - hidden_states = outputs.last_hidden_state - loss = None - logits = None - - if hidden_states is None: - raise ValueError("hidden_states is None") - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **loss_kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - aux_loss = None - if output_router_logits: - aux_loss = load_balancing_loss_func( - outputs.router_logits, - self.num_experts, - self.num_experts_per_tok, - attention_mask, - ) - if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore - loss.device # type: ignore - ) # make sure to reside in the same device - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=aux_loss, # type: ignore - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=outputs.router_logits, - ) - - -def patch_qwen2_moe( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen2_moe import modeling_qwen2_moe - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM - ), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(forward, maybe_model) - - return maybe_model - - modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py deleted file mode 100644 index 79af01cfa..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py +++ /dev/null @@ -1,239 +0,0 @@ -"""Qwen2 VL CCE patch. 
Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch.nn import CrossEntropyLoss -from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLCausalLMOutputWithPast, -) - -_PATCH_OPTS: PatchOptions | None = None - - -def cce_forward_multimodal( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, -) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration - - >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") - - >>> messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What is shown in this image?"}, - ], - }, - ] - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." 
- ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.get_dtype()) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - - # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if ( - (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore - ): - position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask - ) - self.rope_deltas = rope_deltas - # then use the prev pre-calculated rope-deltas to get the correct position ids - else: - batch_size, seq_length, _ = inputs_embeds.shape - delta = ( - cache_position[0] + self.rope_deltas - if cache_position is not None - else 0 - ) - position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore - position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore - delta = delta.to(position_ids.device) # type: ignore - position_ids = position_ids.add(delta) # type: ignore - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore - - outputs = self.model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - logits = None - loss = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.lm_head.weight, - labels, - _PATCH_OPTS, - ) - else: - logits = self.lm_head(hidden_states) - - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Qwen2VLCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=self.rope_deltas, - ) - - -def patch_qwen2_vl( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen2_vl import modeling_qwen2_vl - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration - ), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}." 
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - return maybe_model - - modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py deleted file mode 100644 index 799a4f357..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method.""" - -# pylint: disable=duplicate-code - -from types import MethodType - -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) - - -def patch_qwen3( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - from transformers.models.qwen3 import modeling_qwen3 - - # Set the _PATCH_OPTS in the llama patch file - import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen3.Qwen3ForCausalLM - ), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py deleted file mode 100644 index 90466e64b..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Qwen3 MoE CCE patch. 
Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.models.qwen3_moe.modeling_qwen3_moe import ( - KwargsForCausalLM, - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, - load_balancing_loss_func, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import can_return_tuple - -_PATCH_OPTS: PatchOptions | None = None - - -@can_return_tuple -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> MoeCausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM - - >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits - ) - - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: MoeModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - - if hidden_states is None: - raise ValueError("hidden_states is None") - - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) - - aux_loss = None - if output_router_logits: - aux_loss = load_balancing_loss_func( - outputs.router_logits, - self.num_experts, - self.num_experts_per_tok, - attention_mask, - ) - if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore - loss.device # type: ignore - ) # make sure to reside in the same device - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=aux_loss, # type: ignore - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=outputs.router_logits, - ) - - -def patch_qwen3_moe( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen3_moe import modeling_qwen3_moe - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM - ), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(forward, maybe_model) - - return maybe_model - - modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/utils.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/utils.py deleted file mode 100644 index b808b9f0d..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. 
- -"""Monkeypatch for apply_lce to add softcap.""" - -import torch -from cut_cross_entropy import linear_cross_entropy -from cut_cross_entropy.transformers.utils import PatchOptions - - -def apply_lce( - e: torch.Tensor, - c: torch.Tensor, - labels: torch.Tensor, - opts: PatchOptions, - bias: torch.Tensor | None = None, - softcap: float | None = None, - **loss_kwargs, -) -> torch.Tensor: - """Monkey patch for apply_lce to support softcap kwarg.""" - num_items_in_batch = loss_kwargs.get("num_items_in_batch", None) - cce_kwargs = opts.to_kwargs() - if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean": - cce_kwargs["reduction"] = "sum" - else: - num_items_in_batch = None - - loss = linear_cross_entropy( - e, - c, - labels.to(e.device), - bias=bias, - shift=True, - softcap=softcap, - **cce_kwargs, - ) - - if num_items_in_batch is not None: - loss = loss / num_items_in_batch - - return loss diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 3b2a455ca..bbc532fb9 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -504,6 +504,9 @@ class ModelLoader: # for some reason, this causes the loss to be off by an order of magnitude # but deepspeed needs this still in bfloat16 bnb_config["bnb_4bit_quant_storage"] = torch.float32 + if self.cfg.model_config_type == "falcon_h1": + # output projection cannot be quantized for Falcon-H1 models + bnb_config["llm_int8_skip_modules"] = ["out_proj"] if self.cfg.bnb_config_kwargs: bnb_config.update(self.cfg.bnb_config_kwargs) @@ -518,6 +521,9 @@ class ModelLoader: # Exclude mamba blocks from int8 quantization for jamba if self.cfg.model_config_type == "jamba": bnb_config["llm_int8_skip_modules"] = ["mamba"] + if self.cfg.model_config_type == "falcon_h1": + # output projection cannot be quantized for Falcon-H1 models + bnb_config["llm_int8_skip_modules"] = ["out_proj"] self.model_kwargs["quantization_config"] = BitsAndBytesConfig( **bnb_config, ) diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py index ce9b6a838..080697400 100644 --- a/src/axolotl/processing_strategies.py +++ b/src/axolotl/processing_strategies.py @@ -142,7 +142,7 @@ class ProcessingStrategy: # TODO: check if it's normal to be single image only for common datasets # From observation, it's usually a list of single image but some datasets may have several columns for images # Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages - if len(processed_example[image_key]) > 0: + if len(processed_example[image_key]) > 1: LOG.warning( f"Found {len(processed_example[image_key])} images in a sample. Using the first one." "If you are using a dataset with multiple images per sample, please convert it to use multi-content Messages." 
diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 819616425..d5dd431c1 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -218,6 +218,7 @@ def execute_training( gradient_accumulation_steps=cfg.gradient_accumulation_steps, ring_attn_func=cfg.ring_attn_func, heads_k_stride=cfg.heads_k_stride, + gather_outputs=cfg.rl is RLType.GRPO, ) ) diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 83a42945b..c809ffc7a 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -46,6 +46,7 @@ _CHAT_TEMPLATES = { "command_a_tool_use": '{{ bos_token }}{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n {\n "tool_call_id": "0",\n "results": {\n{% for doc in documents %}\n "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},\n {% endif %}\n{% endfor %}\n\n },\n "is_error": null\n }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls %}\n {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n {{ counter.value }}\n {%- set tool_call_id_seen.value = true %}\n {%- endif %}\n {%- set counter.value = counter.value + 1 %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n {\n "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",\n "results": {\n "0": {{ tool_msg.content|tojson }}\n },\n "is_error": null\n }\n{%- endmacro -%}\n{%- if messages and messages[0][\'role\']|lower == \'system\' %}{%- set developer_preamble = messages[0][\'content\'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user\'s requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. 
Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n NOTE: You MUST skip this step when you are directly responding to the user\'s request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.\n When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n Every tool call produces a list of results (when a tool call produces no result or a single result, it\'ll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you\'ve figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it\'s time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user\'s last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that "Reflection" and "Response" above can be grounded.\nGrounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "<co>" and "</co>" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "<co>span</co: 0:[1,2],1:[0]>" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here.
When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n {"name": "{{ tool[\'function\'][\'name\'] }}", "description": "{{tool[\'function\'][\'description\']}}", "parameters": {{ tool[\'function\'][\'parameters\']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. 
You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == \'system\' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'user\' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n {%- elif message.role|lower == \'assistant\' or message.role|lower == \'chatbot\' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n {% for tc in message.tool_calls %}\n {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc[\'function\'][\'name\'] }}", "parameters": {{ tc[\'function\'][\'arguments\']|tojson }}}{% if not loop.last %},{% endif %}\n\n {% set tool_idx.value = tool_idx.value + 1 %}\n {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n {% elif message.role|lower == \'tool\' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n {%- set stopped = namespace(value=false) %}\n {%- for msg in messages[loop.index0 + 1:] %}\n {%- if not stopped.value and msg.role|lower == \'tool\' %},\n{{ format_tool_message(messages, msg) }}\n {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n {%- else %}\n {%- set stopped.value = true %}\n {%- endif %}\n {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>', "command_a_rag": '{{ bos_token }}{% set tools = [] %}\n{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n {\n "tool_call_id": "0",\n "results": {\n{% for doc in documents %}\n "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},\n {% endif %}\n{% endfor %}\n\n },\n "is_error": null\n }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls %}\n {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n {{ counter.value }}\n {%- set tool_call_id_seen.value = true %}\n {%- endif %}\n {%- set counter.value = counter.value + 1 %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n {\n "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",\n "results": {\n "0": {{ tool_msg.content|tojson }}\n },\n "is_error": null\n }\n{%- endmacro -%}\n{%- if messages and messages[0][\'role\']|lower == \'system\' %}{%- set developer_preamble = 
messages[0][\'content\'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user\'s requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n NOTE: You MUST skip this step when you are directly responding to the user\'s request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.\n When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n Every tool call produces a list of results (when a tool call produces no result or a single result, it\'ll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you\'ve figured out so far, any changes you need to make to your plan, and what you will do next. 
When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it\'s time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user\'s last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that "Reflection" and "Response" above can be grounded.\nGrounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "<co>" and "</co>" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "<co>span</co: 0:[1,2],1:[0]>" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context.
DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n {"name": "{{ tool[\'function\'][\'name\'] }}", "description": "{{tool[\'function\'][\'description\']}}", "parameters": {{ tool[\'function\'][\'parameters\']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. 
You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == \'system\' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'user\' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n {%- elif message.role|lower == \'assistant\' or message.role|lower == \'chatbot\' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n {% for tc in message.tool_calls %}\n {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc[\'function\'][\'name\'] }}", "parameters": {{ tc[\'function\'][\'arguments\']|tojson }}}{% if not loop.last %},{% endif %}\n\n {% set tool_idx.value = tool_idx.value + 1 %}\n {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n {% elif message.role|lower == \'tool\' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n {%- set stopped = namespace(value=false) %}\n {%- for msg in messages[loop.index0 + 1:] %}\n {%- if not stopped.value and msg.role|lower == \'tool\' %},\n{{ format_tool_message(messages, msg) }}\n {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n {%- else %}\n {%- set stopped.value = true %}\n {%- endif %}\n {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>', "aya": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Aya, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", + "falcon_h1": """{{bos_token}}\n{%- if tools %}\n {{- \'<|im_start|>system\\n\' }}\n {%- if messages[0].role == \'system\' %}\n {{- messages[0].content + \'\\n\\n\' }}\n {%- endif %}\n {{- "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query.
Don\'t make assumptions about what values to plug into functions.\\n\\n" }}\n {%- for tool in tools %}[{{- tool | tojson }}]{%- endfor %}\n {{- "\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> tags with the following schema:\\n<tool_call>\\n{\'arguments\': <args-json-object>, \'name\': <function-name>}\\n</tool_call>\\n" }}\n{%- else %}\n {%- if messages[0].role == \'system\' %}\n {{- \'<|im_start|>system\\n\' + messages[0].content + \'<|im_end|>\\n\' }}\n {%- endif %}\n{%- endif %}{% for message in messages %}{%- if message.role != \'system\' %}{{\'<|im_start|>\' + message[\'role\'] + \'\n\' + message[\'content\'] + \'<|im_end|>\' + \'\n\'}}{%- endif %}{% endfor %}{% if add_generation_prompt %}{{ \'<|im_start|>assistant\n\' }}{% endif %}""", } diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 491cb9877..f429cd2ae 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -174,6 +174,8 @@ class SequenceParallelContextManager: ring_attn_func: Which ring attention function to use. Currently unused. heads_k_stride: Sequence parallelism K head stride size. Passed through to `varlen_llama3` `ring_flash_attn` implementation. + gather_outputs: Whether to gather outputs after model forward pass across the + sequence parallel group. """ def __init__( self, @@ -183,12 +185,15 @@ gradient_accumulation_steps: int, ring_attn_func: RingAttnFunc, heads_k_stride: int | None, + gather_outputs: bool, ): self.models = models self.sequence_parallel_degree = sequence_parallel_degree self.gradient_accumulation_steps = gradient_accumulation_steps self.ring_attn_func = ring_attn_func self.heads_k_stride = heads_k_stride + self.gather_outputs = gather_outputs + self._register_ring_attn() # Set distributed info for local rank @@ -277,16 +282,17 @@ return output - # Register both hooks + # Register hooks for model in self.models: self.hook_handles.append( model.register_forward_pre_hook( sequence_parallel_pre_hook, with_kwargs=True ) ) - self.hook_handles.append( - model.register_forward_hook(sequence_parallel_post_hook) - ) + if self.gather_outputs: + self.hook_handles.append( + model.register_forward_hook(sequence_parallel_post_hook) + ) def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: """Gather sharded outputs from all ranks and reconstruct the full tensor.""" diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 7fb5e1b41..95d97e7a0 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -127,7 +127,7 @@ def pack_parallel( bin_size: int, num_processes: int | None = None, safe_mode: bool = True, - mp_start_method: str | None = "spawn", + mp_start_method: str | None = "fork", ) -> list[list[int]]: """Pack sequences into bins using parallel processing.
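The `pack_parallel` change above flips the default start method for the packing worker pool from `spawn` to `fork`: forked workers inherit the parent interpreter's state instead of re-importing it, so the short-lived packing pool starts much faster. The following is a simplified first-fit sketch of how such a pool can honor a configurable start method; `_pack_chunk` and `pack_parallel_sketch` are illustrative names, not Axolotl's internals.

```python
import multiprocessing as mp

import numpy as np


def _pack_chunk(args):
    """First-fit pack one shard of sequence lengths into bins of fixed capacity."""
    lengths, indices, bin_capacity = args
    bins, fill = [], []
    for idx, length in zip(indices, lengths):
        for i, used in enumerate(fill):
            if used + length <= bin_capacity:
                bins[i].append(idx)
                fill[i] += length
                break
        else:  # no open bin fits this sequence; start a new one
            bins.append([idx])
            fill.append(length)
    return bins


def pack_parallel_sketch(lengths, bin_capacity, num_processes=4, mp_start_method="fork"):
    # "fork" starts workers quickly by inheriting the parent process; "spawn"
    # is slower but safer once CUDA is initialized (and the only Windows option).
    ctx = mp.get_context(mp_start_method)
    chunks = np.array_split(np.arange(len(lengths)), num_processes)
    jobs = [(lengths[c], c.tolist(), bin_capacity) for c in chunks]
    with ctx.Pool(processes=num_processes) as pool:
        results = pool.map(_pack_chunk, jobs)
    return [packed_bin for chunk_bins in results for packed_bin in chunk_bins]
```

For example, `pack_parallel_sketch(np.array(dataset_lengths), bin_capacity=4096)` returns lists of dataset indices, one list per packed bin.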
@@ -266,6 +266,7 @@ class MultipackBatchSampler(BatchSampler): bin_size: int = 200, # The max number of samples that can be packed in a single bin num_processes: int | None = None, # Number of processes for parallel packing safe_mode: bool = True, # Conservative packing to prevent training instability + mp_start_method: str = "fork", **kwargs, # pylint: disable=unused-argument ): super().__init__(sampler, batch_size, drop_last) @@ -278,6 +279,7 @@ class MultipackBatchSampler(BatchSampler): self.bin_size = bin_size self.num_processes = num_processes self.safe_mode = safe_mode + self.mp_start_method = mp_start_method assert isinstance(self.lengths, np.ndarray) @@ -338,8 +340,9 @@ class MultipackBatchSampler(BatchSampler): bin_capacity=self.batch_max_len, group_size=self.group_size, bin_size=self.bin_size, - num_processes=self.num_processes, + num_processes=max(4, self.num_processes) if self.num_processes else 4, safe_mode=self.safe_mode, + mp_start_method=self.mp_start_method, ) # Map bin indices back to original indices diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index c698fc3b6..4031742cd 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -393,6 +393,12 @@ class AxolotlInputConfig( default=None, json_schema_extra={"description": "Whether to pack samples sequentially"}, ) + sample_packing_mp_start_method: str | None = Field( + default=None, + json_schema_extra={ + "description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'" + }, + ) eval_sample_packing: bool | None = Field( default=None, json_schema_extra={ diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index bfef14d53..67fc7a8a7 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -54,6 +54,7 @@ class ChatTemplate(str, Enum): jinja = "jinja" qwen_25 = "qwen_25" qwen3 = "qwen3" + falcon_h1 = "falcon_h1" tokenizer_default = "tokenizer_default" exaone = "exaone" metharme = "metharme" diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 633dffde5..554a55abc 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -467,6 +467,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): sequential=cfg.sample_packing_sequentially, drop_last=True, num_processes=cfg.dataset_processes, + mp_start_method=cfg.sample_packing_mp_start_method or "fork", ) data_loader = DataLoader(
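With the `sample_packing_mp_start_method` schema field above and the plumbing in `calculate_total_num_steps`, the packing start method becomes user-configurable, falling back to `fork` when unset. A hypothetical YAML excerpt showing how one might override it (values are illustrative, not defaults asserted by this diff):

```yaml
# Override the packing pool's start method, e.g. if "fork" hangs because CUDA
# or threaded dataloaders were initialized before packing (hypothetical values)
sample_packing: true
sample_packing_mp_start_method: spawn  # "fork", "spawn", or "forkserver"
dataset_processes: 8
```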