Compare commits

...

13 Commits

Author SHA1 Message Date
NanoCode012
1f2f285173 fix: missing key in enum 2025-07-03 13:46:16 +08:00
NanoCode012
98e912e416 feat: add custom processing strategy for phi35 vl 2025-07-03 13:46:16 +08:00
NanoCode012
e1528fb381 feat: add phi_35_vl support 2025-07-03 13:46:16 +08:00
NanoCode012
8ae5a2311b feat: update handling for mistraltokenizer decode and multiprocessing pickling fix (#2790)
* feat: update handling for mistraltokenizer decode

* fix: update mistral common package version

* fix: to use correct release

* fix triton path

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-07-02 08:07:18 -04:00
NanoCode012
6383630155 Fix: tokenize stall due to not shuffling dataset (#2845)
* fix: shuffle dataset even if only one to fix tokenize stall

* fix: warn if shuffling merged with curriculum sampling

* chore: refactor
2025-07-02 08:06:00 -04:00
Vincenzo di Cicco
f2b352f2e5 Add sample_packing_sequentially to trainer args (#2853) [skip ci] 2025-07-02 08:05:35 -04:00
NanoCode012
bf5928d0ee feat(doc): update docker tag examples (#2851) [skip ci]
* feat(doc): update docker tag examples

* chore: comment
2025-07-02 08:05:01 -04:00
Dhruv Mullick
d1224db8f4 Decouple generate_during_eval from wandb to support other visualizers (#2849) [skip ci]
* Add generate_during_eval for mlflow for dpo

* Decouple generate_during_eval from wandb
2025-07-02 08:04:40 -04:00
mhenrichsen
327b4e48e9 Add installation instructions for pip and Docker to README.md (#2854)
* Add installation instructions for pip and Docker to README.md

* Enhance README.md with Docker installation guidance for improved setup reliability.
2025-07-02 09:03:52 +02:00
Dan Saunders
35fdbce102 Ensure device mesh patching is applied (#2842)
* move patches; make patch stronger

* fix broken tests

* guard sequence_parallel_degree comparison against none

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-29 22:16:32 -04:00
Wing Lian
cb811f8bf1 upgrade to flash-attn 2.8.0.post2 (#2828)
* upgrade to flash-attn 2.8.0.post2

* use cu126 with torch 2.6

* seems vllm 0.8.5.post1 not compatible with cuda12.6.3 and torch 2.6

* cu126 + torch 2.6 as the default

* use cu126 for multigpu w torch 2.6 too

* drop vllm for now from ci for now
2025-06-29 22:11:16 -04:00
Wing Lian
7563e1bd30 set a different triton cache for each test to avoid blocking writes to cache (#2843)
* set a different triton cache for each test to avoid blocking writes to cache

* set log level

* disable debug logging for filelock
2025-06-29 22:05:21 -04:00
Wing Lian
81893c775c Accelerate 1.8.1 and BNB 0.46.0 update (#2815)
* update accelerate to v1.8.0

* update bnb also

* fix multigpu ci timeout

* fix test set size

* use latest accelerate 1.8.1

* disable default dtype
2025-06-28 15:29:19 -04:00
34 changed files with 190 additions and 63 deletions

View File

@@ -20,12 +20,11 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
- cuda: 124 - cuda: 126
cuda_version: 12.4.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: vllm axolotl_extras: vllm
is_latest: true
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
@@ -88,8 +87,8 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
- cuda: 124 - cuda: 126
cuda_version: 12.4.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: axolotl_extras:
@@ -146,8 +145,8 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 124 - cuda: 126
cuda_version: 12.4.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: axolotl_extras:

View File

@@ -26,11 +26,11 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 124 - cuda: 126
cuda_version: 12.4.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: vllm axolotl_extras:
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
- cuda: 124 - cuda: 124

View File

@@ -195,12 +195,12 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 124 - cuda: 126
cuda_version: 12.4.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
num_gpus: 1 num_gpus: 1
axolotl_extras: vllm axolotl_extras:
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
@@ -247,8 +247,8 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 124 - cuda: 126
cuda_version: 12.4.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
num_gpus: 1 num_gpus: 1
@@ -311,7 +311,7 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
num_gpus: 1 num_gpus: 1
axolotl_extras: vllm axolotl_extras:
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4

View File

@@ -59,6 +59,8 @@ Features:
### Installation ### Installation
#### Using pip
```bash ```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed] pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
@@ -68,6 +70,13 @@ axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL axolotl fetch deepspeed_configs # OPTIONAL
``` ```
#### Using Docker
Installing with Docker can be less error prone than installing in your own environment.
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html). Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
### Your First Fine-tune ### Your First Fine-tune

View File

@@ -32,6 +32,8 @@ df_args = {
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""), "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""), "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub", "HF_HOME": "/workspace/data/huggingface-cache/hub",
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
} }
dockerfile_contents = df_template.render(**df_args) dockerfile_contents = df_template.render(**df_args)

View File

@@ -37,7 +37,3 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \ pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working # The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -34,7 +34,3 @@ RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \ && uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \ && uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic && uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
fi

View File

@@ -9,7 +9,7 @@ format:
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai). This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important} ::: {.callout-important}
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8. For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
::: :::
## Base ## Base
@@ -34,6 +34,7 @@ Tags examples:
- `main-base-py3.11-cu128-2.7.1` - `main-base-py3.11-cu128-2.7.1`
- `main-base-py3.11-cu126-2.7.1` - `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu126-2.6.0`
- `main-base-py3.11-cu124-2.6.0` - `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1` - `main-base-py3.11-cu124-2.5.1`
@@ -73,13 +74,15 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples: Tags examples:
- `main-py3.11-cu126-2.7.0` - `main-py3.11-cu128-2.7.1`
- `main-py3.11-cu126-2.7.1`
- `main-py3.11-cu126-2.6.0`
- `main-py3.11-cu124-2.6.0` - `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1` - `main-py3.11-cu124-2.5.1`
- `main-latest` - `main-latest`
- `main-20250303-py3.11-cu124-2.6.0` - `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1` - `main-20250303-py3.11-cu124-2.5.1`
- `0.9.2` - `0.10.1`
## Cloud ## Cloud

View File

@@ -16,6 +16,7 @@ format:
- [Gemma-3](#sec-gemma-3) - [Gemma-3](#sec-gemma-3)
- [Qwen2-VL](#sec-qwen2-vl) - [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl) - [Qwen2.5-VL](#sec-qwen25-vl)
- [Phi3-V](#sec-phi3-v)
## Usage ## Usage
@@ -126,6 +127,15 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl chat_template: qwen2_vl # same as qwen2-vl
``` ```
### Phi3-V {#sec-phi3-v}
```yaml
base_model: microsoft/Phi-3.5-vision-instruct
trust_remote_code: true
chat_template: phi_35_vl
```
## Dataset Format ## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format. For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS # START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.4 bitsandbytes==0.46.0
triton>=3.0.0 triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
@@ -15,7 +15,7 @@ huggingface_hub==0.32.2
peft==0.15.2 peft==0.15.2
transformers==4.52.4 transformers==4.52.4
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.7.0 accelerate==1.8.1
datasets==3.6.0 datasets==3.6.0
deepspeed>=0.17.0 deepspeed>=0.17.0
trl==0.18.2 trl==0.18.2
@@ -68,4 +68,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6 axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3 axolotl-contribs-mit==0.0.3
mistral-common==1.6.0 mistral-common==1.6.3

View File

@@ -111,9 +111,9 @@ def get_package_version():
extras_require = { extras_require = {
"flash-attn": ["flash-attn==2.7.4.post1"], "flash-attn": ["flash-attn==2.8.0.post2"],
"ring-flash-attn": [ "ring-flash-attn": [
"flash-attn==2.7.4.post1", "flash-attn==2.8.0.post2",
"ring-flash-attn>=0.1.4", "ring-flash-attn>=0.1.4",
"yunchang==0.6.0", "yunchang==0.6.0",
], ],

View File

@@ -253,6 +253,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["eval_sample_packing"] = bool( training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing self.cfg.eval_sample_packing
) )
if self.cfg.sample_packing_sequentially is not None:
training_arguments_kwargs["sample_packing_sequentially"] = (
self.cfg.sample_packing_sequentially
)
if self.cfg.sample_packing_bin_size is not None: if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs["sample_packing_bin_size"] = ( training_arguments_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size self.cfg.sample_packing_bin_size

View File

@@ -28,7 +28,7 @@ class DPOStrategy:
training_args_kwargs["max_completion_length"] = None training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_prompt_length"] = cfg.sequence_len training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.use_wandb training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
if cfg.dpo_use_weighting is not None: if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None: if cfg.dpo_padding_free is not None:

View File

@@ -1,6 +1,7 @@
"""Shared constants for axolotl.loaders module""" """Shared constants for axolotl.loaders module"""
from transformers import ( from transformers import (
AutoModelForCausalLM,
Gemma3ForConditionalGeneration, Gemma3ForConditionalGeneration,
Llama4ForConditionalGeneration, Llama4ForConditionalGeneration,
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
@@ -18,4 +19,6 @@ MULTIMODAL_AUTO_MODEL_MAPPING = {
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration, "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration, "mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration, "gemma3": Gemma3ForConditionalGeneration,
# phi3_v modeling code is not available in transformers yet
"phi3_v": AutoModelForCausalLM,
} }

View File

@@ -65,6 +65,7 @@ class PatchManager:
self._apply_mistral_cross_entropy_patch() self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch() self._apply_self_attention_lora_patch()
self._apply_gemma3_conditional_generation_forward_patch() self._apply_gemma3_conditional_generation_forward_patch()
self._apply_sequence_parallel_patches()
def apply_post_model_load_patches(self, model: PreTrainedModel): def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance.""" """Apply patches that require the model instance."""
@@ -231,6 +232,17 @@ class PatchManager:
patch_gemma3_conditional_generation_forward() patch_gemma3_conditional_generation_forward()
def _apply_sequence_parallel_patches(self):
"""Apply sequence parallelism patches."""
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.ring_attn.patch import (
patch_prepare_data_loader,
patch_prepare_device_mesh,
)
patch_prepare_data_loader()
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
def _patch_attention(self): def _patch_attention(self):
"""Apply attention-specific patches based on model type.""" """Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")): if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):

View File

@@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
def patch_prepare_data_loader(): def patch_prepare_data_loader():
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree. """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
Raies: Raises:
RuntimeError: If source code to patch does not exist. RuntimeError: If source code to patch does not exist.
""" """
original_fn = accelerate.data_loader.prepare_data_loader original_fn = accelerate.data_loader.prepare_data_loader
@@ -168,23 +168,34 @@ def patch_prepare_data_loader():
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
) )
items_to_import = []
for item in dir(accelerate.data_loader):
if item in patched_source:
items_to_import.append(item)
# Create a new function from the patched source # Create a new function from the patched source
namespace = {} namespace = {}
exec( # pylint: disable=exec-used # nosec B102 exec( # pylint: disable=exec-used # nosec B102
patched_source, accelerate.data_loader.__dict__, namespace f"from accelerate.data_loader import ({', '.join(items_to_import)})",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
patched_source, globals(), namespace
) )
patched_function = namespace["prepare_data_loader"]
accelerate.data_loader.prepare_data_loader = patched_function patched_function = namespace["prepare_data_loader"]
original_fn.__code__ = patched_function.__code__
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support") LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int): def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree. that includes sequence parallelism with the specified degree.
Args: Args:
sequence_parallel_degree (int): The degree of sequence parallelism to use. sequence_parallel_degree: The degree of sequence parallelism to use.
fsdp: Whether to use FSDP.
""" """
def _prepare_device_mesh(self): def _prepare_device_mesh(self):
@@ -207,12 +218,14 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
) )
device_ids = list(range(world_size)) device_ids = list(range(world_size))
# Note that we use "cp" instead of "sp" to match the PyTorch native "context # NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming # parallelism" implementation naming.
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
# only use "fsdp" and "cp" for the device mesh.
return dist.DeviceMesh( return dist.DeviceMesh(
"cuda", "cuda",
torch.tensor(device_ids).reshape(mesh_shape), torch.tensor(device_ids).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"), mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
) )
# Replace the original method with our new method # Replace the original method with our new method

View File

@@ -264,6 +264,23 @@ class Gemma3ProcessingStrategy(ProcessingStrategy):
return labels return labels
class Phi35VLProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Phi-3.5-vision-instruct"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<|image|>" # nosec
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
def get_processing_strategy( def get_processing_strategy(
processor: ProcessorMixin, processor: ProcessorMixin,
chat_template, chat_template,
@@ -279,6 +296,10 @@ def get_processing_strategy(
return Gemma3ProcessingStrategy( return Gemma3ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm processor, chat_template, image_size, image_resize_algorithm
) )
if chat_template_type == "phi_35_vl":
return Phi35VLProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type in [ if chat_template_type in [
"llama3_2_vision", "llama3_2_vision",
"llama4", "llama4",

View File

@@ -223,8 +223,9 @@ def execute_training(
) )
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
if cfg.bf16: # TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
torch.set_default_dtype(torch.bfloat16) # if cfg.bf16:
# torch.set_default_dtype(torch.bfloat16)
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -32,6 +32,7 @@ _CHAT_TEMPLATES = {
"llava": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}", "llava": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}", "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}",
"phi_35_vl": "{% set image_count = namespace(value=0) %}{% for message in messages %}{{'<|' + message['role'] + '|>\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% set message_images = [] %}{% set message_text = [] %}{% for chunk in message['content'] %}{% if chunk['type'] == 'image' or 'image' in chunk or 'image_url' in chunk %}{% set image_count.value = image_count.value + 1 %}{% set _ = message_images.append('<|image_' + image_count.value|string + '|>\n') %}{% elif chunk['type'] == 'text' %}{% set _ = message_text.append(chunk['text']) %}{% endif %}{% endfor %}{{ message_images | join('') }}{{ message_text | join('') }}{% endif %}{{ '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
"phi_4": "{% set system_message = 'You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> {Thought section} </think> {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines:' -%}{%- if messages and messages[0]['role'] == 'system' -%}{%- set system_message = messages[0]['content'] -%}{%- set messages = messages[1:] -%}{%- endif -%}<|im_start|>system<|im_sep|>{{ system_message }}<|im_end|>{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>'}}{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}", "phi_4": "{% set system_message = 'You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> {Thought section} </think> {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines:' -%}{%- if messages and messages[0]['role'] == 'system' -%}{%- set system_message = messages[0]['content'] -%}{%- set messages = messages[1:] -%}{%- endif -%}<|im_start|>system<|im_sep|>{{ system_message }}<|im_end|>{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>'}}{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}",
"deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<User>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<Assistant>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<Assistant>' }}{% endif %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<User>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<Assistant>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<Assistant>' }}{% endif %}",
"deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- else %}{{'<Assistant>' + message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- endfor %}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}", "deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- else %}{{'<Assistant>' + message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- endfor %}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}",

View File

@@ -12,8 +12,6 @@ from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import ( from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group, get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn, register_ring_attn,
update_ring_attn_params, update_ring_attn_params,
) )
@@ -238,12 +236,6 @@ class SequenceParallelContextManager:
ring_attn_func=self.ring_attn_func, ring_attn_func=self.ring_attn_func,
) )
# Patches for accelerate functionality
patch_prepare_data_loader()
patch_prepare_device_mesh(
sequence_parallel_degree=self.sequence_parallel_degree
)
def _register_model_hooks(self): def _register_model_hooks(self):
# Forward pre-hook to apply sequence parallelism # Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs): def sequence_parallel_pre_hook(_, args, kwargs):

View File

@@ -524,13 +524,24 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
Merged dataset. Merged dataset.
""" """
if len(datasets) == 1: if len(datasets) == 1:
return datasets[0] ds = datasets[0]
# Do not shuffle if curriculum sampling is enabled
if cfg.curriculum_sampling:
return ds
return ds.shuffle(seed=cfg.seed)
LOG.info("Merging datasets...") LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets) merged_dataset = concatenate_datasets(datasets)
if cfg.shuffle_merged_datasets: if cfg.shuffle_merged_datasets:
LOG.debug("Shuffling merged datasets...") LOG.debug("Shuffling merged datasets...")
if cfg.curriculum_sampling:
LOG.warning(
"Shuffling merged datasets with curriculum sampling is not recommended. "
"This will randomize the order of samples."
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed) merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
else: else:
LOG.debug("Not shuffling merged datasets.") LOG.debug("Not shuffling merged datasets.")

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
from torch import Tensor from torch import Tensor
from transformers.utils import PaddingStrategy from transformers.utils import PaddingStrategy
@@ -251,10 +251,13 @@ class HFMistralTokenizer:
token_ids = [token_ids] token_ids = [token_ids]
if skip_special_tokens: if skip_special_tokens:
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids) return self._mistral.instruct_tokenizer.tokenizer.decode(
token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
)
# to_string returns a string with special tokens return self._mistral.instruct_tokenizer.tokenizer.decode(
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids) token_ids, special_token_policy=SpecialTokenPolicy.KEEP
)
def _create_mistral_chat_completion_request( def _create_mistral_chat_completion_request(
self, conversation: list[dict], tools: list[dict] | None = None self, conversation: list[dict], tools: list[dict] | None = None

View File

@@ -146,6 +146,7 @@ class AxolotlInputConfig(
dpo_label_smoothing: float | None = None dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None dpo_norm_loss: bool | None = None
dpo_padding_free: bool | None = None dpo_padding_free: bool | None = None
dpo_generate_during_eval: bool | None = None
datasets: ( datasets: (
Annotated[ Annotated[

View File

@@ -48,6 +48,8 @@ class ChatTemplate(str, Enum):
llama4 = "llama4" llama4 = "llama4"
phi_3 = "phi_3" phi_3 = "phi_3"
phi_35 = "phi_35" phi_35 = "phi_35"
phi_35_vl = "phi_35_vl"
phi_4 = "phi_4"
deepseek_v2 = "deepseek_v2" deepseek_v2 = "deepseek_v2"
deepseek_v3 = "deepseek_v3" deepseek_v3 = "deepseek_v3"
jamba = "jamba" jamba = "jamba"

View File

@@ -4,12 +4,14 @@ shared pytest fixtures
import functools import functools
import importlib import importlib
import logging
import os import os
import shutil import shutil
import sys import sys
import tempfile import tempfile
import time import time
from pathlib import Path from pathlib import Path, PosixPath
from typing import Generator
import datasets import datasets
import pytest import pytest
@@ -24,6 +26,8 @@ from tests.hf_offline_utils import (
hf_offline_context, hf_offline_context,
) )
logging.getLogger("filelock").setLevel(logging.CRITICAL)
def retry_on_request_exceptions(max_retries=3, delay=1): def retry_on_request_exceptions(max_retries=3, delay=1):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -411,7 +415,7 @@ def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
@pytest.fixture @pytest.fixture
def temp_dir(): def temp_dir() -> Generator[str, None, None]:
# Create a temporary directory # Create a temporary directory
_temp_dir = tempfile.mkdtemp() _temp_dir = tempfile.mkdtemp()
yield _temp_dir yield _temp_dir
@@ -419,6 +423,11 @@ def temp_dir():
shutil.rmtree(_temp_dir) shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def unique_triton_cache_dir(temp_dir: str | PosixPath) -> None:
os.environ["TRITON_CACHE_DIR"] = str(temp_dir) + "/.triton/cache"
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches(): def cleanup_monkeypatches():
from transformers import Trainer from transformers import Trainer

View File

@@ -54,6 +54,7 @@ class TestSequenceParallelism:
"micro_batch_size": micro_batch_size, "micro_batch_size": micro_batch_size,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -54,6 +54,7 @@ class TestPackedFlex:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"gradient_checkpointing": True, "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -309,6 +309,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"warmup_steps": 10, "warmup_steps": 10,
"val_set_size": 0.0, "val_set_size": 0.0,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001, "learning_rate": 0.0001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -400,6 +401,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"warmup_steps": 10, "warmup_steps": 10,
"val_set_size": 0.0, "val_set_size": 0.0,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001, "learning_rate": 0.0001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -38,12 +38,13 @@ class TestMultiGPUEval:
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"], "lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.004, "val_set_size": 0.05,
"special_tokens": {"pad_token": "<|endoftext|>"}, "special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [ "datasets": [
{ {
"path": "teknium/GPT4-LLM-Cleaned", "path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca", "type": "alpaca",
"split": "train[:5%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
@@ -51,6 +52,7 @@ class TestMultiGPUEval:
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -107,12 +109,13 @@ class TestMultiGPUEval:
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"], "lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.0004, "val_set_size": 0.01,
"special_tokens": {"pad_token": "<|endoftext|>"}, "special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [ "datasets": [
{ {
"path": "teknium/GPT4-LLM-Cleaned", "path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca", "type": "alpaca",
"split": "train[:5%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
@@ -120,6 +123,7 @@ class TestMultiGPUEval:
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -64,6 +64,7 @@ class TestMultiGPUGemma3:
}, },
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001, "learning_rate": 0.0001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -62,6 +62,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -127,6 +128,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -200,6 +202,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0, "warmup_steps": 0,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
@@ -278,6 +281,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0, "warmup_steps": 0,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
@@ -340,6 +344,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -412,6 +417,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -491,6 +497,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"gradient_checkpointing": True, "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_8bit", "optimizer": "adamw_torch_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -573,6 +580,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -669,6 +677,7 @@ class TestMultiGPULlama:
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -743,6 +752,7 @@ class TestMultiGPULlama:
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -817,6 +827,7 @@ class TestMultiGPULlama:
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -46,6 +46,7 @@ class TestMultiGPUQwen2:
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -48,6 +48,7 @@ class TestMultiGPURay:
"micro_batch_size": 4, "micro_batch_size": 4,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -107,6 +108,7 @@ class TestMultiGPURay:
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -396,7 +396,7 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
def test_kernel_training_integration(): def test_kernel_training_integration(temp_dir):
"""Test model loading with kernel patches enabled.""" """Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
@@ -426,6 +426,14 @@ def test_kernel_training_integration():
} }
) )
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Load model # Load model
model, _, _ = load_model_and_tokenizer(cfg=cfg) model, _, _ = load_model_and_tokenizer(cfg=cfg)
@@ -505,7 +513,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
assert found_patched_attn assert found_patched_attn
def test_kernel_training_integration_dropout_non_zero(): def test_kernel_training_integration_dropout_non_zero(temp_dir):
"""Test model loading with dropout non-zero should not patch.""" """Test model loading with dropout non-zero should not patch."""
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
@@ -533,6 +541,14 @@ def test_kernel_training_integration_dropout_non_zero():
} }
) )
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Get original attention class # Get original attention class
attention_cls = get_attention_cls_from_config(cfg) attention_cls = get_attention_cls_from_config(cfg)