Compare commits

..

2 Commits

Author SHA1 Message Date
mhenrhcsen
8eba033dc4 fix: correct attention class retrieval for gemma3n model in lora_kernels.py 2025-06-27 19:30:09 +02:00
mhenrhcsen
a9c0f43202 fix: update attention class import logic for gemma3n model 2025-06-27 19:27:36 +02:00
35 changed files with 68 additions and 191 deletions

View File

@@ -20,11 +20,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras: vllm
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -87,8 +88,8 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
@@ -145,8 +146,8 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:

View File

@@ -26,11 +26,11 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 124

View File

@@ -195,12 +195,12 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -247,8 +247,8 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
@@ -311,7 +311,7 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
axolotl_extras: vllm
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -59,8 +59,6 @@ Features:
### Installation
#### Using pip
```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
@@ -70,13 +68,6 @@ axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```
#### Using Docker
Installing with Docker can be less error prone than installing in your own environment.
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
### Your First Fine-tune

View File

@@ -32,8 +32,6 @@ df_args = {
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
}
dockerfile_contents = df_template.render(**df_args)

View File

@@ -37,3 +37,7 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -34,3 +34,7 @@ RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
fi

View File

@@ -9,7 +9,7 @@ format:
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
:::
## Base
@@ -34,7 +34,6 @@ Tags examples:
- `main-base-py3.11-cu128-2.7.1`
- `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu126-2.6.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
@@ -74,15 +73,13 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples:
- `main-py3.11-cu128-2.7.1`
- `main-py3.11-cu126-2.7.1`
- `main-py3.11-cu126-2.6.0`
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `0.10.1`
- `0.9.2`
## Cloud

View File

@@ -16,7 +16,6 @@ format:
- [Gemma-3](#sec-gemma-3)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [Phi3-V](#sec-phi3-v)
## Usage
@@ -127,15 +126,6 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Phi3-V {#sec-phi3-v}
```yaml
base_model: microsoft/Phi-3.5-vision-instruct
trust_remote_code: true
chat_template: phi_35_vl
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.46.0
bitsandbytes==0.45.4
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
@@ -15,7 +15,7 @@ huggingface_hub==0.32.2
peft==0.15.2
transformers==4.52.4
tokenizers>=0.21.1
accelerate==1.8.1
accelerate==1.7.0
datasets==3.6.0
deepspeed>=0.17.0
trl==0.18.2
@@ -68,4 +68,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
mistral-common==1.6.3
mistral-common==1.6.0

View File

@@ -111,9 +111,9 @@ def get_package_version():
extras_require = {
"flash-attn": ["flash-attn==2.8.0.post2"],
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": [
"flash-attn==2.8.0.post2",
"flash-attn==2.7.4.post1",
"ring-flash-attn>=0.1.4",
"yunchang==0.6.0",
],

View File

@@ -253,10 +253,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
)
if self.cfg.sample_packing_sequentially is not None:
training_arguments_kwargs["sample_packing_sequentially"] = (
self.cfg.sample_packing_sequentially
)
if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size

View File

@@ -28,7 +28,7 @@ class DPOStrategy:
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None:

View File

@@ -1,7 +1,6 @@
"""Shared constants for axolotl.loaders module"""
from transformers import (
AutoModelForCausalLM,
Gemma3ForConditionalGeneration,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
@@ -19,6 +18,4 @@ MULTIMODAL_AUTO_MODEL_MAPPING = {
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration,
# phi3_v modeling code is not available in transformers yet
"phi3_v": AutoModelForCausalLM,
}

View File

@@ -65,7 +65,6 @@ class PatchManager:
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_gemma3_conditional_generation_forward_patch()
self._apply_sequence_parallel_patches()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
@@ -232,17 +231,6 @@ class PatchManager:
patch_gemma3_conditional_generation_forward()
def _apply_sequence_parallel_patches(self):
"""Apply sequence parallelism patches."""
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.ring_attn.patch import (
patch_prepare_data_loader,
patch_prepare_device_mesh,
)
patch_prepare_data_loader()
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
def _patch_attention(self):
"""Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):

View File

@@ -156,8 +156,12 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
if model_type == "gemma3n":
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
attention_cls = getattr(module, f"{model_cls_prefix}TextAttention")
else:
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
return attention_cls
except (ImportError, AttributeError) as e:

View File

@@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
def patch_prepare_data_loader():
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
Raises:
Raies:
RuntimeError: If source code to patch does not exist.
"""
original_fn = accelerate.data_loader.prepare_data_loader
@@ -168,34 +168,23 @@ def patch_prepare_data_loader():
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
)
items_to_import = []
for item in dir(accelerate.data_loader):
if item in patched_source:
items_to_import.append(item)
# Create a new function from the patched source
namespace = {}
exec( # pylint: disable=exec-used # nosec B102
f"from accelerate.data_loader import ({', '.join(items_to_import)})",
globals(),
patched_source, accelerate.data_loader.__dict__, namespace
)
exec( # pylint: disable=exec-used # nosec B102
patched_source, globals(), namespace
)
patched_function = namespace["prepare_data_loader"]
original_fn.__code__ = patched_function.__code__
accelerate.data_loader.prepare_data_loader = patched_function
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
def patch_prepare_device_mesh(sequence_parallel_degree: int):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree.
Args:
sequence_parallel_degree: The degree of sequence parallelism to use.
fsdp: Whether to use FSDP.
sequence_parallel_degree (int): The degree of sequence parallelism to use.
"""
def _prepare_device_mesh(self):
@@ -218,14 +207,12 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False)
)
device_ids = list(range(world_size))
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming.
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
# only use "fsdp" and "cp" for the device mesh.
# Note that we use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming
return dist.DeviceMesh(
"cuda",
torch.tensor(device_ids).reshape(mesh_shape),
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
mesh_dim_names=("dp", "cp"),
)
# Replace the original method with our new method

View File

@@ -264,23 +264,6 @@ class Gemma3ProcessingStrategy(ProcessingStrategy):
return labels
class Phi35VLProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Phi-3.5-vision-instruct"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<|image|>" # nosec
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -296,10 +279,6 @@ def get_processing_strategy(
return Gemma3ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type == "phi_35_vl":
return Phi35VLProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type in [
"llama3_2_vision",
"llama4",

View File

@@ -223,9 +223,8 @@ def execute_training(
)
LOG.info("Starting trainer...")
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
# if cfg.bf16:
# torch.set_default_dtype(torch.bfloat16)
if cfg.bf16:
torch.set_default_dtype(torch.bfloat16)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -32,7 +32,6 @@ _CHAT_TEMPLATES = {
"llava": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}",
"phi_35_vl": "{% set image_count = namespace(value=0) %}{% for message in messages %}{{'<|' + message['role'] + '|>\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% set message_images = [] %}{% set message_text = [] %}{% for chunk in message['content'] %}{% if chunk['type'] == 'image' or 'image' in chunk or 'image_url' in chunk %}{% set image_count.value = image_count.value + 1 %}{% set _ = message_images.append('<|image_' + image_count.value|string + '|>\n') %}{% elif chunk['type'] == 'text' %}{% set _ = message_text.append(chunk['text']) %}{% endif %}{% endfor %}{{ message_images | join('') }}{{ message_text | join('') }}{% endif %}{{ '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
"phi_4": "{% set system_message = 'You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> {Thought section} </think> {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines:' -%}{%- if messages and messages[0]['role'] == 'system' -%}{%- set system_message = messages[0]['content'] -%}{%- set messages = messages[1:] -%}{%- endif -%}<|im_start|>system<|im_sep|>{{ system_message }}<|im_end|>{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>'}}{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}",
"deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<User>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<Assistant>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<Assistant>' }}{% endif %}",
"deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- else %}{{'<Assistant>' + message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- endfor %}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}",

View File

@@ -12,6 +12,8 @@ from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
update_ring_attn_params,
)
@@ -236,6 +238,12 @@ class SequenceParallelContextManager:
ring_attn_func=self.ring_attn_func,
)
# Patches for accelerate functionality
patch_prepare_data_loader()
patch_prepare_device_mesh(
sequence_parallel_degree=self.sequence_parallel_degree
)
def _register_model_hooks(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):

View File

@@ -524,24 +524,13 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
Merged dataset.
"""
if len(datasets) == 1:
ds = datasets[0]
# Do not shuffle if curriculum sampling is enabled
if cfg.curriculum_sampling:
return ds
return ds.shuffle(seed=cfg.seed)
return datasets[0]
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)
if cfg.shuffle_merged_datasets:
LOG.debug("Shuffling merged datasets...")
if cfg.curriculum_sampling:
LOG.warning(
"Shuffling merged datasets with curriculum sampling is not recommended. "
"This will randomize the order of samples."
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
else:
LOG.debug("Not shuffling merged datasets.")

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
import numpy as np
from huggingface_hub import hf_hub_download
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from torch import Tensor
from transformers.utils import PaddingStrategy
@@ -251,13 +251,10 @@ class HFMistralTokenizer:
token_ids = [token_ids]
if skip_special_tokens:
return self._mistral.instruct_tokenizer.tokenizer.decode(
token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
)
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
return self._mistral.instruct_tokenizer.tokenizer.decode(
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
)
# to_string returns a string with special tokens
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
def _create_mistral_chat_completion_request(
self, conversation: list[dict], tools: list[dict] | None = None

View File

@@ -146,7 +146,6 @@ class AxolotlInputConfig(
dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None
dpo_padding_free: bool | None = None
dpo_generate_during_eval: bool | None = None
datasets: (
Annotated[

View File

@@ -48,8 +48,6 @@ class ChatTemplate(str, Enum):
llama4 = "llama4"
phi_3 = "phi_3"
phi_35 = "phi_35"
phi_35_vl = "phi_35_vl"
phi_4 = "phi_4"
deepseek_v2 = "deepseek_v2"
deepseek_v3 = "deepseek_v3"
jamba = "jamba"

View File

@@ -4,14 +4,12 @@ shared pytest fixtures
import functools
import importlib
import logging
import os
import shutil
import sys
import tempfile
import time
from pathlib import Path, PosixPath
from typing import Generator
from pathlib import Path
import datasets
import pytest
@@ -26,8 +24,6 @@ from tests.hf_offline_utils import (
hf_offline_context,
)
logging.getLogger("filelock").setLevel(logging.CRITICAL)
def retry_on_request_exceptions(max_retries=3, delay=1):
# pylint: disable=duplicate-code
@@ -415,7 +411,7 @@ def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
@pytest.fixture
def temp_dir() -> Generator[str, None, None]:
def temp_dir():
# Create a temporary directory
_temp_dir = tempfile.mkdtemp()
yield _temp_dir
@@ -423,11 +419,6 @@ def temp_dir() -> Generator[str, None, None]:
shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def unique_triton_cache_dir(temp_dir: str | PosixPath) -> None:
os.environ["TRITON_CACHE_DIR"] = str(temp_dir) + "/.triton/cache"
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer

View File

@@ -54,7 +54,6 @@ class TestSequenceParallelism:
"micro_batch_size": micro_batch_size,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",

View File

@@ -54,7 +54,6 @@ class TestPackedFlex:
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",

View File

@@ -309,7 +309,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -401,7 +400,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",

View File

@@ -38,13 +38,12 @@ class TestMultiGPUEval:
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.05,
"val_set_size": 0.004,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
"split": "train[:5%]",
},
],
"num_epochs": 1,
@@ -52,7 +51,6 @@ class TestMultiGPUEval:
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -109,13 +107,12 @@ class TestMultiGPUEval:
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.01,
"val_set_size": 0.0004,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
"split": "train[:5%]",
},
],
"num_epochs": 1,
@@ -123,7 +120,6 @@ class TestMultiGPUEval:
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",

View File

@@ -64,7 +64,6 @@ class TestMultiGPUGemma3:
},
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",

View File

@@ -62,7 +62,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -128,7 +127,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -202,7 +200,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -281,7 +278,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -344,7 +340,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -417,7 +412,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -497,7 +491,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_8bit",
"lr_scheduler": "cosine",
@@ -580,7 +573,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -677,7 +669,6 @@ class TestMultiGPULlama:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -752,7 +743,6 @@ class TestMultiGPULlama:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -827,7 +817,6 @@ class TestMultiGPULlama:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",

View File

@@ -46,7 +46,6 @@ class TestMultiGPUQwen2:
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",

View File

@@ -48,7 +48,6 @@ class TestMultiGPURay:
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -108,7 +107,6 @@ class TestMultiGPURay:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",

View File

@@ -396,7 +396,7 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code
def test_kernel_training_integration(temp_dir):
def test_kernel_training_integration():
"""Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
@@ -426,14 +426,6 @@ def test_kernel_training_integration(temp_dir):
}
)
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Load model
model, _, _ = load_model_and_tokenizer(cfg=cfg)
@@ -513,7 +505,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
assert found_patched_attn
def test_kernel_training_integration_dropout_non_zero(temp_dir):
def test_kernel_training_integration_dropout_non_zero():
"""Test model loading with dropout non-zero should not patch."""
from axolotl.cli.utils import load_model_and_tokenizer
@@ -541,14 +533,6 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
}
)
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Get original attention class
attention_cls = get_attention_cls_from_config(cfg)