feat: add internvl3_5 (#3141) [skip-ci]
* feat: add internvl3_5
* fix: add timm instructions
* chore: add kimi-linear to cce doc
* feat: update internvl example
* chore: pin revision
* chore: remove from multipack
* fix: add to multimodal array
* fix: internvl use hf version
* feat: update cce
* chore: lint
* fix: list for image_size
* chore: add docs vram usage
* feat: enable cce
* fix: no need trust remote code
* fix: inconsistent timm version
This commit is contained in:
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@242b245"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -54,6 +54,7 @@ plugins:
|
||||
- granitemoehybrid
|
||||
- hunyuan_v1_dense
|
||||
- hunyuan_v1_moe
|
||||
- internvl
|
||||
- kimi_linear
|
||||
- lfm2
|
||||
- lfm2_moe
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@242b245"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -79,7 +79,11 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
||||
and hasattr(model_config, "vision_config")
|
||||
and hasattr(model_config.vision_config, "image_size")
|
||||
):
|
||||
cfg.image_size = model_config.vision_config.image_size
|
||||
image_size = model_config.vision_config.image_size
|
||||
if isinstance(image_size, list):
|
||||
cfg.image_size = tuple(image_size)
|
||||
else:
|
||||
cfg.image_size = image_size
|
||||
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
|
||||
|
||||
quant_config_exists = (
|
||||
|
||||
@@ -8,6 +8,7 @@ from PIL.Image import Resampling
|
||||
from torch import Tensor, zeros_like
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.image_utils import load_image
|
||||
from transformers.models.internvl import InternVLProcessor
|
||||
from transformers.models.smolvlm import SmolVLMProcessor
|
||||
from transformers.models.voxtral import VoxtralProcessor
|
||||
|
||||
@@ -454,6 +455,37 @@ class Mistral3ProcessingStrategy(ProcessingStrategy):
|
||||
return labels
|
||||
|
||||
|
||||
class InternVLProcessingStrategy(ProcessingStrategy):
    """Processing Strategy class for InternVL"""

    def __init__(
        self,
        processor: ProcessorMixin,
        chat_template: Optional[str] = None,
        image_size: int | tuple[int, int] | None = None,
        image_resize_algorithm: Resampling | None = None,
    ):
        super().__init__(processor, chat_template, image_size, image_resize_algorithm)

        # InternVL's processor exposes the token ids used for image patches;
        # fail fast if this processor does not provide them.
        if not hasattr(processor, "image_ids"):
            raise ValueError("'image_ids' missing from InternVL Processor.")

        self.image_token_ids = processor.image_ids

    def process_labels(self, input_ids):
        """Build labels from input ids, masking padding and image tokens with -100
        so they are ignored by the loss."""
        labels = input_ids.clone()

        # Mask every image-patch token id exposed by the processor.
        for token_id in self.image_token_ids:
            labels[labels == token_id] = -100

        # Mask padding tokens as well.
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        # Note: Check if need to mask 'video_token' as it gets converted to
        # image patches during media processing

        return labels
|
||||
|
||||
|
||||
def get_processing_strategy(
|
||||
processor: ProcessorMixin,
|
||||
chat_template,
|
||||
@@ -501,6 +533,11 @@ def get_processing_strategy(
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(processor, InternVLProcessor):
|
||||
return InternVLProcessingStrategy(
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
# llama3_2_vision, llama4, llava
|
||||
# mistral_v7_tekken, pixtral, lfm2vl
|
||||
return ProcessingStrategy(
|
||||
|
||||
Reference in New Issue
Block a user