feat: add llama4 multimodal (#2499)
* feat: add llama4 multimodal
* feat: add torchvision to base docker
* just use latest torchvision

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
|||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ format:
|
|||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
- [Mllama](#sec-mllama)
|
- [Mllama](#sec-mllama)
|
||||||
|
- [Llama4](#sec-llama4)
|
||||||
- [Pixtral](#sec-pixtral)
|
- [Pixtral](#sec-pixtral)
|
||||||
- [Llava-1.5](#sec-llava-15)
|
- [Llava-1.5](#sec-llava-15)
|
||||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||||
@@ -63,6 +64,14 @@ base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
|
|||||||
chat_template: llama3_2_vision
|
chat_template: llama3_2_vision
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Llama4 {#sec-llama4}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||||
|
|
||||||
|
chat_template: llama4
|
||||||
|
```
|
||||||
|
|
||||||
### Pixtral {#sec-pixtral}
|
### Pixtral {#sec-pixtral}
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -268,6 +268,7 @@ def get_processing_strategy(
|
|||||||
)
|
)
|
||||||
if chat_template_type in [
|
if chat_template_type in [
|
||||||
"llama3_2_vision",
|
"llama3_2_vision",
|
||||||
|
"llama4",
|
||||||
"llava",
|
"llava",
|
||||||
"mistral_v7_tekken",
|
"mistral_v7_tekken",
|
||||||
"pixtral",
|
"pixtral",
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -36,6 +36,7 @@ from transformers import (
|
|||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
Gemma3ForConditionalGeneration,
|
Gemma3ForConditionalGeneration,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
|
Llama4ForConditionalGeneration,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
Mistral3ForConditionalGeneration,
|
Mistral3ForConditionalGeneration,
|
||||||
MllamaForConditionalGeneration,
|
MllamaForConditionalGeneration,
|
||||||
@@ -76,6 +77,7 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
MULTIMODAL_AUTO_MODEL_MAPPING = {
|
MULTIMODAL_AUTO_MODEL_MAPPING = {
|
||||||
"mllama": MllamaForConditionalGeneration,
|
"mllama": MllamaForConditionalGeneration,
|
||||||
|
"llama4": Llama4ForConditionalGeneration,
|
||||||
"llava": LlavaForConditionalGeneration,
|
"llava": LlavaForConditionalGeneration,
|
||||||
"qwen2_vl": Qwen2VLForConditionalGeneration,
|
"qwen2_vl": Qwen2VLForConditionalGeneration,
|
||||||
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
|
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ class ChatTemplate(str, Enum):
|
|||||||
llama3 = "llama3" # pylint: disable=invalid-name
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
llama4 = "llama4" # pylint: disable=invalid-name
|
llama4 = "llama4" # pylint: disable=invalid-name
|
||||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||||
|
llama4 = "llama4" # pylint: disable=invalid-name
|
||||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||||
|
|||||||
Reference in New Issue
Block a user