Compare commits

...

13 Commits

Author SHA1 Message Date
NanoCode012
1f2f285173 fix: missing key in enum 2025-07-03 13:46:16 +08:00
NanoCode012
98e912e416 feat: add custom processing strategy for phi35 vl 2025-07-03 13:46:16 +08:00
NanoCode012
e1528fb381 feat: add phi_35_vl support 2025-07-03 13:46:16 +08:00
NanoCode012
8ae5a2311b feat: update handling for mistraltokenizer decode and multiprocessing pickling fix (#2790)
* feat: update handling for mistraltokenizer decode

* fix: update mistral common package version

* fix: to use correct release

* fix triton path

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-07-02 08:07:18 -04:00
NanoCode012
6383630155 Fix: tokenize stall due to not shuffling dataset (#2845)
* fix: shuffle dataset even if only one to fix tokenize stall

* fix: warn if shuffling merged with curriculum sampling

* chore: refactor
2025-07-02 08:06:00 -04:00
Vincenzo di Cicco
f2b352f2e5 Add sample_packing_sequentially to trainer args (#2853) [skip ci] 2025-07-02 08:05:35 -04:00
NanoCode012
bf5928d0ee feat(doc): update docker tag examples (#2851) [skip ci]
* feat(doc): update docker tag examples

* chore: comment
2025-07-02 08:05:01 -04:00
Dhruv Mullick
d1224db8f4 Decouple generate_during_eval from wandb to support other visualizers (#2849) [skip ci]
* Add generate_during_eval for mlflow for dpo

* Decouple generate_during_eval from wandb
2025-07-02 08:04:40 -04:00
mhenrichsen
327b4e48e9 Add installation instructions for pip and Docker to README.md (#2854)
* Add installation instructions for pip and Docker to README.md

* Enhance README.md with Docker installation guidance for improved setup reliability.
2025-07-02 09:03:52 +02:00
Dan Saunders
35fdbce102 Ensure device mesh patching is applied (#2842)
* move patches; make patch stronger

* fix broken tests

* guard sequence_parallel_degree comparison against none

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-29 22:16:32 -04:00
Wing Lian
cb811f8bf1 upgrade to flash-attn 2.8.0.post2 (#2828)
* upgrade to flash-attn 2.8.0.post2

* use cu126 with torch 2.6

* seems vllm 0.8.5.post1 not compatible with cuda12.6.3 and torch 2.6

* cu126 + torch 2.6 as the default

* use cu126 for multigpu w torch 2.6 too

* drop vllm for now from ci for now
2025-06-29 22:11:16 -04:00
Wing Lian
7563e1bd30 set a different triton cache for each test to avoid blocking writes to cache (#2843)
* set a different triton cache for each test to avoid blocking writes to cache

* set log level

* disable debug logging for filelock
2025-06-29 22:05:21 -04:00
Wing Lian
81893c775c Accelerate 1.8.1 and BNB 0.46.0 update (#2815)
* update accelerate to v1.8.0

* update bnb also

* fix multigpu ci timeout

* fix test set size

* use latest accelerate 1.8.1

* disable default dtype
2025-06-28 15:29:19 -04:00
34 changed files with 190 additions and 63 deletions

View File

@@ -20,12 +20,11 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras: vllm
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -88,8 +87,8 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
@@ -146,8 +145,8 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:

View File

@@ -26,11 +26,11 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras: vllm
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 124

View File

@@ -195,12 +195,12 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -247,8 +247,8 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
@@ -311,7 +311,7 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -59,6 +59,8 @@ Features:
### Installation
#### Using pip
```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
@@ -68,6 +70,13 @@ axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```
#### Using Docker
Installing with Docker can be less error prone than installing in your own environment.
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
### Your First Fine-tune

View File

@@ -32,6 +32,8 @@ df_args = {
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
}
dockerfile_contents = df_template.render(**df_args)

View File

@@ -37,7 +37,3 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -34,7 +34,3 @@ RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
fi

View File

@@ -9,7 +9,7 @@ format:
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
:::
## Base
@@ -34,6 +34,7 @@ Tags examples:
- `main-base-py3.11-cu128-2.7.1`
- `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu126-2.6.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
@@ -73,13 +74,15 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples:
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu128-2.7.1`
- `main-py3.11-cu126-2.7.1`
- `main-py3.11-cu126-2.6.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `0.9.2`
- `0.10.1`
## Cloud

View File

@@ -16,6 +16,7 @@ format:
- [Gemma-3](#sec-gemma-3)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [Phi3-V](#sec-phi3-v)
## Usage
@@ -126,6 +127,15 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Phi3-V {#sec-phi3-v}
```yaml
base_model: microsoft/Phi-3.5-vision-instruct
trust_remote_code: true
chat_template: phi_35_vl
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
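For orientation, a hedged sketch of one sample in that extended message format, with field names inferred from the `phi_35_vl` chat template later in this diff (`role`, `content`, `type`, `text`, `image`); the exact image field the loader accepts may differ:

```yaml
# Illustrative sketch only: one conversation in the extended chat_template format.
messages:
  - role: user
    content:
      - type: image
        image: path/to/image.jpg   # field name assumed; the phi_35_vl template also matches image_url
      - type: text
        text: What is shown in this picture?
  - role: assistant
    content:
      - type: text
        text: A short description of the image.
```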

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.4
bitsandbytes==0.46.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
@@ -15,7 +15,7 @@ huggingface_hub==0.32.2
peft==0.15.2
transformers==4.52.4
tokenizers>=0.21.1
accelerate==1.7.0
accelerate==1.8.1
datasets==3.6.0
deepspeed>=0.17.0
trl==0.18.2
@@ -68,4 +68,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
mistral-common==1.6.0
mistral-common==1.6.3

View File

@@ -111,9 +111,9 @@ def get_package_version():
extras_require = {
"flash-attn": ["flash-attn==2.7.4.post1"],
"flash-attn": ["flash-attn==2.8.0.post2"],
"ring-flash-attn": [
"flash-attn==2.7.4.post1",
"flash-attn==2.8.0.post2",
"ring-flash-attn>=0.1.4",
"yunchang==0.6.0",
],

View File

@@ -253,6 +253,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
)
if self.cfg.sample_packing_sequentially is not None:
training_arguments_kwargs["sample_packing_sequentially"] = (
self.cfg.sample_packing_sequentially
)
if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size
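For context, a minimal config sketch that would exercise the new trainer argument; `sample_packing_sequentially` and `sample_packing_bin_size` are the keys plumbed through in this hunk, the values are illustrative only:

```yaml
# Illustrative values only; sample_packing_sequentially is the key wired
# into the trainer arguments by this change.
sample_packing: true
sample_packing_sequentially: true
sample_packing_bin_size: 200
```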

View File

@@ -28,7 +28,7 @@ class DPOStrategy:
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None:
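A hedged sketch of the config this change enables; `dpo_generate_during_eval` is the new key (also added to `AxolotlInputConfig` further down) and `dpo_use_weighting` appears in the surrounding code, while the trainer selector is assumed:

```yaml
# Illustrative values only; generation during eval is now controlled by its
# own flag instead of cfg.use_wandb.
rl: dpo                          # assumed selector for the DPO trainer path
dpo_generate_during_eval: true   # new key, also added to AxolotlInputConfig below
dpo_use_weighting: false
```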

View File

@@ -1,6 +1,7 @@
"""Shared constants for axolotl.loaders module"""
from transformers import (
AutoModelForCausalLM,
Gemma3ForConditionalGeneration,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
@@ -18,4 +19,6 @@ MULTIMODAL_AUTO_MODEL_MAPPING = {
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration,
# phi3_v modeling code is not available in transformers yet
"phi3_v": AutoModelForCausalLM,
}

View File

@@ -65,6 +65,7 @@ class PatchManager:
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_gemma3_conditional_generation_forward_patch()
self._apply_sequence_parallel_patches()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
@@ -231,6 +232,17 @@ class PatchManager:
patch_gemma3_conditional_generation_forward()
def _apply_sequence_parallel_patches(self):
"""Apply sequence parallelism patches."""
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.ring_attn.patch import (
patch_prepare_data_loader,
patch_prepare_device_mesh,
)
patch_prepare_data_loader()
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
def _patch_attention(self):
"""Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
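For context, a minimal config sketch under which `_apply_sequence_parallel_patches` takes effect; `sequence_parallel_degree`, `flash_attention`, and `fsdp` are the cfg attributes referenced in this hunk, the values are illustrative:

```yaml
# Illustrative values only; the sequence-parallel patches above are applied
# when sequence_parallel_degree > 1.
sequence_parallel_degree: 2
flash_attention: true
# If FSDP is also configured, the patched device mesh uses ("fsdp", "cp")
# instead of ("dp", "cp") as its dimension names.
fsdp:
  - full_shard   # assumed FSDP setting; any truthy cfg.fsdp takes this branch
```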

View File

@@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
def patch_prepare_data_loader():
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
Raies:
Raises:
RuntimeError: If source code to patch does not exist.
"""
original_fn = accelerate.data_loader.prepare_data_loader
@@ -168,23 +168,34 @@ def patch_prepare_data_loader():
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
)
items_to_import = []
for item in dir(accelerate.data_loader):
if item in patched_source:
items_to_import.append(item)
# Create a new function from the patched source
namespace = {}
exec( # pylint: disable=exec-used # nosec B102
patched_source, accelerate.data_loader.__dict__, namespace
f"from accelerate.data_loader import ({', '.join(items_to_import)})",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
patched_source, globals(), namespace
)
patched_function = namespace["prepare_data_loader"]
accelerate.data_loader.prepare_data_loader = patched_function
patched_function = namespace["prepare_data_loader"]
original_fn.__code__ = patched_function.__code__
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int):
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree.
Args:
sequence_parallel_degree (int): The degree of sequence parallelism to use.
sequence_parallel_degree: The degree of sequence parallelism to use.
fsdp: Whether to use FSDP.
"""
def _prepare_device_mesh(self):
@@ -207,12 +218,14 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
)
device_ids = list(range(world_size))
# Note that we use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming.
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
# only use "fsdp" and "cp" for the device mesh.
return dist.DeviceMesh(
"cuda",
torch.tensor(device_ids).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
)
# Replace the original method with our new method

View File

@@ -264,6 +264,23 @@ class Gemma3ProcessingStrategy(ProcessingStrategy):
return labels
class Phi35VLProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Phi-3.5-vision-instruct"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<|image|>" # nosec
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -279,6 +296,10 @@ def get_processing_strategy(
return Gemma3ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type == "phi_35_vl":
return Phi35VLProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type in [
"llama3_2_vision",
"llama4",

View File

@@ -223,8 +223,9 @@ def execute_training(
)
LOG.info("Starting trainer...")
if cfg.bf16:
torch.set_default_dtype(torch.bfloat16)
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
# if cfg.bf16:
# torch.set_default_dtype(torch.bfloat16)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -32,6 +32,7 @@ _CHAT_TEMPLATES = {
"llava": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}",
"phi_35_vl": "{% set image_count = namespace(value=0) %}{% for message in messages %}{{'<|' + message['role'] + '|>\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% set message_images = [] %}{% set message_text = [] %}{% for chunk in message['content'] %}{% if chunk['type'] == 'image' or 'image' in chunk or 'image_url' in chunk %}{% set image_count.value = image_count.value + 1 %}{% set _ = message_images.append('<|image_' + image_count.value|string + '|>\n') %}{% elif chunk['type'] == 'text' %}{% set _ = message_text.append(chunk['text']) %}{% endif %}{% endfor %}{{ message_images | join('') }}{{ message_text | join('') }}{% endif %}{{ '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
"phi_4": "{% set system_message = 'You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> {Thought section} </think> {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines:' -%}{%- if messages and messages[0]['role'] == 'system' -%}{%- set system_message = messages[0]['content'] -%}{%- set messages = messages[1:] -%}{%- endif -%}<|im_start|>system<|im_sep|>{{ system_message }}<|im_end|>{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>'}}{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}",
"deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<User>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<Assistant>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<Assistant>' }}{% endif %}",
"deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- else %}{{'<Assistant>' + message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- endfor %}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}",

View File

@@ -12,8 +12,6 @@ from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
update_ring_attn_params,
)
@@ -238,12 +236,6 @@ class SequenceParallelContextManager:
ring_attn_func=self.ring_attn_func,
)
# Patches for accelerate functionality
patch_prepare_data_loader()
patch_prepare_device_mesh(
sequence_parallel_degree=self.sequence_parallel_degree
)
def _register_model_hooks(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):

View File

@@ -524,13 +524,24 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
Merged dataset.
"""
if len(datasets) == 1:
return datasets[0]
ds = datasets[0]
# Do not shuffle if curriculum sampling is enabled
if cfg.curriculum_sampling:
return ds
return ds.shuffle(seed=cfg.seed)
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)
if cfg.shuffle_merged_datasets:
LOG.debug("Shuffling merged datasets...")
if cfg.curriculum_sampling:
LOG.warning(
"Shuffling merged datasets with curriculum sampling is not recommended. "
"This will randomize the order of samples."
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
else:
LOG.debug("Not shuffling merged datasets.")

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
import numpy as np
from huggingface_hub import hf_hub_download
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
from torch import Tensor
from transformers.utils import PaddingStrategy
@@ -251,10 +251,13 @@ class HFMistralTokenizer:
token_ids = [token_ids]
if skip_special_tokens:
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
return self._mistral.instruct_tokenizer.tokenizer.decode(
token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
)
# to_string returns a string with special tokens
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
return self._mistral.instruct_tokenizer.tokenizer.decode(
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
)
def _create_mistral_chat_completion_request(
self, conversation: list[dict], tools: list[dict] | None = None

View File

@@ -146,6 +146,7 @@ class AxolotlInputConfig(
dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None
dpo_padding_free: bool | None = None
dpo_generate_during_eval: bool | None = None
datasets: (
Annotated[

View File

@@ -48,6 +48,8 @@ class ChatTemplate(str, Enum):
llama4 = "llama4"
phi_3 = "phi_3"
phi_35 = "phi_35"
phi_35_vl = "phi_35_vl"
phi_4 = "phi_4"
deepseek_v2 = "deepseek_v2"
deepseek_v3 = "deepseek_v3"
jamba = "jamba"

View File

@@ -4,12 +4,14 @@ shared pytest fixtures
import functools
import importlib
import logging
import os
import shutil
import sys
import tempfile
import time
from pathlib import Path
from pathlib import Path, PosixPath
from typing import Generator
import datasets
import pytest
@@ -24,6 +26,8 @@ from tests.hf_offline_utils import (
hf_offline_context,
)
logging.getLogger("filelock").setLevel(logging.CRITICAL)
def retry_on_request_exceptions(max_retries=3, delay=1):
# pylint: disable=duplicate-code
@@ -411,7 +415,7 @@ def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
@pytest.fixture
def temp_dir():
def temp_dir() -> Generator[str, None, None]:
# Create a temporary directory
_temp_dir = tempfile.mkdtemp()
yield _temp_dir
@@ -419,6 +423,11 @@ def temp_dir():
shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def unique_triton_cache_dir(temp_dir: str | PosixPath) -> None:
os.environ["TRITON_CACHE_DIR"] = str(temp_dir) + "/.triton/cache"
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer

View File

@@ -54,6 +54,7 @@ class TestSequenceParallelism:
"micro_batch_size": micro_batch_size,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",

View File

@@ -54,6 +54,7 @@ class TestPackedFlex:
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",

View File

@@ -309,6 +309,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -400,6 +401,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",

View File

@@ -38,12 +38,13 @@ class TestMultiGPUEval:
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.004,
"val_set_size": 0.05,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
"split": "train[:5%]",
},
],
"num_epochs": 1,
@@ -51,6 +52,7 @@ class TestMultiGPUEval:
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -107,12 +109,13 @@ class TestMultiGPUEval:
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.0004,
"val_set_size": 0.01,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
"split": "train[:5%]",
},
],
"num_epochs": 1,
@@ -120,6 +123,7 @@ class TestMultiGPUEval:
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",

View File

@@ -64,6 +64,7 @@ class TestMultiGPUGemma3:
},
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",

View File

@@ -62,6 +62,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -127,6 +128,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -200,6 +202,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -278,6 +281,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -340,6 +344,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -412,6 +417,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -491,6 +497,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_8bit",
"lr_scheduler": "cosine",
@@ -573,6 +580,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -669,6 +677,7 @@ class TestMultiGPULlama:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -743,6 +752,7 @@ class TestMultiGPULlama:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -817,6 +827,7 @@ class TestMultiGPULlama:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",

View File

@@ -46,6 +46,7 @@ class TestMultiGPUQwen2:
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",

View File

@@ -48,6 +48,7 @@ class TestMultiGPURay:
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -107,6 +108,7 @@ class TestMultiGPURay:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",

View File

@@ -396,7 +396,7 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code
def test_kernel_training_integration():
def test_kernel_training_integration(temp_dir):
"""Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
@@ -426,6 +426,14 @@ def test_kernel_training_integration():
}
)
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Load model
model, _, _ = load_model_and_tokenizer(cfg=cfg)
@@ -505,7 +513,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
assert found_patched_attn
def test_kernel_training_integration_dropout_non_zero():
def test_kernel_training_integration_dropout_non_zero(temp_dir):
"""Test model loading with dropout non-zero should not patch."""
from axolotl.cli.utils import load_model_and_tokenizer
@@ -533,6 +541,14 @@ def test_kernel_training_integration_dropout_non_zero():
}
)
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Get original attention class
attention_cls = get_attention_cls_from_config(cfg)