Compare commits

..

2 Commits

Author SHA1 Message Date
mhenrhcsen
8eba033dc4 fix: correct attention class retrieval for gemma3n model in lora_kernels.py 2025-06-27 19:30:09 +02:00
mhenrhcsen
a9c0f43202 fix: update attention class import logic for gemma3n model 2025-06-27 19:27:36 +02:00
35 changed files with 68 additions and 191 deletions

View File

@@ -20,11 +20,12 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
- cuda: 126 - cuda: 124
cuda_version: 12.6.3 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: vllm axolotl_extras: vllm
is_latest: true
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
@@ -87,8 +88,8 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
- cuda: 126 - cuda: 124
cuda_version: 12.6.3 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: axolotl_extras:
@@ -145,8 +146,8 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 126 - cuda: 124
cuda_version: 12.6.3 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: axolotl_extras:

View File

@@ -26,11 +26,11 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126 - cuda: 124
cuda_version: 12.6.3 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: axolotl_extras: vllm
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
- cuda: 124 - cuda: 124

View File

@@ -195,12 +195,12 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126 - cuda: 124
cuda_version: 12.6.3 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras: vllm
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
@@ -247,8 +247,8 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126 - cuda: 124
cuda_version: 12.6.3 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
num_gpus: 1 num_gpus: 1
@@ -311,7 +311,7 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras: vllm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4

View File

@@ -59,8 +59,6 @@ Features:
### Installation ### Installation
#### Using pip
```bash ```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed] pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
@@ -70,13 +68,6 @@ axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL axolotl fetch deepspeed_configs # OPTIONAL
``` ```
#### Using Docker
Installing with Docker can be less error prone than installing in your own environment.
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html). Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
### Your First Fine-tune ### Your First Fine-tune

View File

@@ -32,8 +32,6 @@ df_args = {
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""), "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""), "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub", "HF_HOME": "/workspace/data/huggingface-cache/hub",
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
} }
dockerfile_contents = df_template.render(**df_args) dockerfile_contents = df_template.render(**df_args)

View File

@@ -37,3 +37,7 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \ pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working # The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -34,3 +34,7 @@ RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \ && uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \ && uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic && uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
fi

View File

@@ -9,7 +9,7 @@ format:
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai). This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important} ::: {.callout-important}
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8. For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
::: :::
## Base ## Base
@@ -34,7 +34,6 @@ Tags examples:
- `main-base-py3.11-cu128-2.7.1` - `main-base-py3.11-cu128-2.7.1`
- `main-base-py3.11-cu126-2.7.1` - `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu126-2.6.0`
- `main-base-py3.11-cu124-2.6.0` - `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1` - `main-base-py3.11-cu124-2.5.1`
@@ -74,15 +73,13 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples: Tags examples:
- `main-py3.11-cu128-2.7.1` - `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu126-2.7.1`
- `main-py3.11-cu126-2.6.0`
- `main-py3.11-cu124-2.6.0` - `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1` - `main-py3.11-cu124-2.5.1`
- `main-latest` - `main-latest`
- `main-20250303-py3.11-cu124-2.6.0` - `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1` - `main-20250303-py3.11-cu124-2.5.1`
- `0.10.1` - `0.9.2`
## Cloud ## Cloud

View File

@@ -16,7 +16,6 @@ format:
- [Gemma-3](#sec-gemma-3) - [Gemma-3](#sec-gemma-3)
- [Qwen2-VL](#sec-qwen2-vl) - [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl) - [Qwen2.5-VL](#sec-qwen25-vl)
- [Phi3-V](#sec-phi3-v)
## Usage ## Usage
@@ -127,15 +126,6 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl chat_template: qwen2_vl # same as qwen2-vl
``` ```
### Phi3-V {#sec-phi3-v}
```yaml
base_model: microsoft/Phi-3.5-vision-instruct
trust_remote_code: true
chat_template: phi_35_vl
```
## Dataset Format ## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format. For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS # START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.46.0 bitsandbytes==0.45.4
triton>=3.0.0 triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
@@ -15,7 +15,7 @@ huggingface_hub==0.32.2
peft==0.15.2 peft==0.15.2
transformers==4.52.4 transformers==4.52.4
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.8.1 accelerate==1.7.0
datasets==3.6.0 datasets==3.6.0
deepspeed>=0.17.0 deepspeed>=0.17.0
trl==0.18.2 trl==0.18.2
@@ -68,4 +68,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6 axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3 axolotl-contribs-mit==0.0.3
mistral-common==1.6.3 mistral-common==1.6.0

View File

@@ -111,9 +111,9 @@ def get_package_version():
extras_require = { extras_require = {
"flash-attn": ["flash-attn==2.8.0.post2"], "flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": [ "ring-flash-attn": [
"flash-attn==2.8.0.post2", "flash-attn==2.7.4.post1",
"ring-flash-attn>=0.1.4", "ring-flash-attn>=0.1.4",
"yunchang==0.6.0", "yunchang==0.6.0",
], ],

View File

@@ -253,10 +253,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["eval_sample_packing"] = bool( training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing self.cfg.eval_sample_packing
) )
if self.cfg.sample_packing_sequentially is not None:
training_arguments_kwargs["sample_packing_sequentially"] = (
self.cfg.sample_packing_sequentially
)
if self.cfg.sample_packing_bin_size is not None: if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs["sample_packing_bin_size"] = ( training_arguments_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size self.cfg.sample_packing_bin_size

View File

@@ -28,7 +28,7 @@ class DPOStrategy:
training_args_kwargs["max_completion_length"] = None training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_prompt_length"] = cfg.sequence_len training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval training_args_kwargs["generate_during_eval"] = cfg.use_wandb
if cfg.dpo_use_weighting is not None: if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None: if cfg.dpo_padding_free is not None:

View File

@@ -1,7 +1,6 @@
"""Shared constants for axolotl.loaders module""" """Shared constants for axolotl.loaders module"""
from transformers import ( from transformers import (
AutoModelForCausalLM,
Gemma3ForConditionalGeneration, Gemma3ForConditionalGeneration,
Llama4ForConditionalGeneration, Llama4ForConditionalGeneration,
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
@@ -19,6 +18,4 @@ MULTIMODAL_AUTO_MODEL_MAPPING = {
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration, "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration, "mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration, "gemma3": Gemma3ForConditionalGeneration,
# phi3_v modeling code is not available in transformers yet
"phi3_v": AutoModelForCausalLM,
} }

View File

@@ -65,7 +65,6 @@ class PatchManager:
self._apply_mistral_cross_entropy_patch() self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch() self._apply_self_attention_lora_patch()
self._apply_gemma3_conditional_generation_forward_patch() self._apply_gemma3_conditional_generation_forward_patch()
self._apply_sequence_parallel_patches()
def apply_post_model_load_patches(self, model: PreTrainedModel): def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance.""" """Apply patches that require the model instance."""
@@ -232,17 +231,6 @@ class PatchManager:
patch_gemma3_conditional_generation_forward() patch_gemma3_conditional_generation_forward()
def _apply_sequence_parallel_patches(self):
"""Apply sequence parallelism patches."""
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.ring_attn.patch import (
patch_prepare_data_loader,
patch_prepare_device_mesh,
)
patch_prepare_data_loader()
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
def _patch_attention(self): def _patch_attention(self):
"""Apply attention-specific patches based on model type.""" """Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")): if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):

View File

@@ -156,8 +156,12 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
model_cls_prefix = "".join( model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")] [part.capitalize() for part in model_type.split("_")]
) )
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"]) if model_type == "gemma3n":
attention_cls = getattr(module, f"{model_cls_prefix}Attention") module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
attention_cls = getattr(module, f"{model_cls_prefix}TextAttention")
else:
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
return attention_cls return attention_cls
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:

View File

@@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
def patch_prepare_data_loader(): def patch_prepare_data_loader():
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree. """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
Raises: Raies:
RuntimeError: If source code to patch does not exist. RuntimeError: If source code to patch does not exist.
""" """
original_fn = accelerate.data_loader.prepare_data_loader original_fn = accelerate.data_loader.prepare_data_loader
@@ -168,34 +168,23 @@ def patch_prepare_data_loader():
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
) )
items_to_import = []
for item in dir(accelerate.data_loader):
if item in patched_source:
items_to_import.append(item)
# Create a new function from the patched source # Create a new function from the patched source
namespace = {} namespace = {}
exec( # pylint: disable=exec-used # nosec B102 exec( # pylint: disable=exec-used # nosec B102
f"from accelerate.data_loader import ({', '.join(items_to_import)})", patched_source, accelerate.data_loader.__dict__, namespace
globals(),
) )
exec( # pylint: disable=exec-used # nosec B102
patched_source, globals(), namespace
)
patched_function = namespace["prepare_data_loader"] patched_function = namespace["prepare_data_loader"]
original_fn.__code__ = patched_function.__code__
accelerate.data_loader.prepare_data_loader = patched_function
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support") LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False): def patch_prepare_device_mesh(sequence_parallel_degree: int):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree. that includes sequence parallelism with the specified degree.
Args: Args:
sequence_parallel_degree: The degree of sequence parallelism to use. sequence_parallel_degree (int): The degree of sequence parallelism to use.
fsdp: Whether to use FSDP.
""" """
def _prepare_device_mesh(self): def _prepare_device_mesh(self):
@@ -218,14 +207,12 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False)
) )
device_ids = list(range(world_size)) device_ids = list(range(world_size))
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context # Note that we use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming. # parallelism" implementation naming
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
# only use "fsdp" and "cp" for the device mesh.
return dist.DeviceMesh( return dist.DeviceMesh(
"cuda", "cuda",
torch.tensor(device_ids).reshape(mesh_shape), torch.tensor(device_ids).reshape(mesh_shape),
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"), mesh_dim_names=("dp", "cp"),
) )
# Replace the original method with our new method # Replace the original method with our new method

View File

@@ -264,23 +264,6 @@ class Gemma3ProcessingStrategy(ProcessingStrategy):
return labels return labels
class Phi35VLProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Phi-3.5-vision-instruct"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<|image|>" # nosec
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
def get_processing_strategy( def get_processing_strategy(
processor: ProcessorMixin, processor: ProcessorMixin,
chat_template, chat_template,
@@ -296,10 +279,6 @@ def get_processing_strategy(
return Gemma3ProcessingStrategy( return Gemma3ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm processor, chat_template, image_size, image_resize_algorithm
) )
if chat_template_type == "phi_35_vl":
return Phi35VLProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type in [ if chat_template_type in [
"llama3_2_vision", "llama3_2_vision",
"llama4", "llama4",

View File

@@ -223,9 +223,8 @@ def execute_training(
) )
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers if cfg.bf16:
# if cfg.bf16: torch.set_default_dtype(torch.bfloat16)
# torch.set_default_dtype(torch.bfloat16)
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -32,7 +32,6 @@ _CHAT_TEMPLATES = {
"llava": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}", "llava": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}", "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}",
"phi_35_vl": "{% set image_count = namespace(value=0) %}{% for message in messages %}{{'<|' + message['role'] + '|>\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% set message_images = [] %}{% set message_text = [] %}{% for chunk in message['content'] %}{% if chunk['type'] == 'image' or 'image' in chunk or 'image_url' in chunk %}{% set image_count.value = image_count.value + 1 %}{% set _ = message_images.append('<|image_' + image_count.value|string + '|>\n') %}{% elif chunk['type'] == 'text' %}{% set _ = message_text.append(chunk['text']) %}{% endif %}{% endfor %}{{ message_images | join('') }}{{ message_text | join('') }}{% endif %}{{ '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
"phi_4": "{% set system_message = 'You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> {Thought section} </think> {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines:' -%}{%- if messages and messages[0]['role'] == 'system' -%}{%- set system_message = messages[0]['content'] -%}{%- set messages = messages[1:] -%}{%- endif -%}<|im_start|>system<|im_sep|>{{ system_message }}<|im_end|>{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>'}}{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}", "phi_4": "{% set system_message = 'You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> {Thought section} </think> {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines:' -%}{%- if messages and messages[0]['role'] == 'system' -%}{%- set system_message = messages[0]['content'] -%}{%- set messages = messages[1:] -%}{%- endif -%}<|im_start|>system<|im_sep|>{{ system_message }}<|im_end|>{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>'}}{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}",
"deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<User>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<Assistant>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<Assistant>' }}{% endif %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<User>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<Assistant>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<Assistant>' }}{% endif %}",
"deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- else %}{{'<Assistant>' + message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- endfor %}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}", "deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- else %}{{'<Assistant>' + message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- endfor %}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}",

View File

@@ -12,6 +12,8 @@ from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import ( from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group, get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn, register_ring_attn,
update_ring_attn_params, update_ring_attn_params,
) )
@@ -236,6 +238,12 @@ class SequenceParallelContextManager:
ring_attn_func=self.ring_attn_func, ring_attn_func=self.ring_attn_func,
) )
# Patches for accelerate functionality
patch_prepare_data_loader()
patch_prepare_device_mesh(
sequence_parallel_degree=self.sequence_parallel_degree
)
def _register_model_hooks(self): def _register_model_hooks(self):
# Forward pre-hook to apply sequence parallelism # Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs): def sequence_parallel_pre_hook(_, args, kwargs):

View File

@@ -524,24 +524,13 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
Merged dataset. Merged dataset.
""" """
if len(datasets) == 1: if len(datasets) == 1:
ds = datasets[0] return datasets[0]
# Do not shuffle if curriculum sampling is enabled
if cfg.curriculum_sampling:
return ds
return ds.shuffle(seed=cfg.seed)
LOG.info("Merging datasets...") LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets) merged_dataset = concatenate_datasets(datasets)
if cfg.shuffle_merged_datasets: if cfg.shuffle_merged_datasets:
LOG.debug("Shuffling merged datasets...") LOG.debug("Shuffling merged datasets...")
if cfg.curriculum_sampling:
LOG.warning(
"Shuffling merged datasets with curriculum sampling is not recommended. "
"This will randomize the order of samples."
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed) merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
else: else:
LOG.debug("Not shuffling merged datasets.") LOG.debug("Not shuffling merged datasets.")

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from torch import Tensor from torch import Tensor
from transformers.utils import PaddingStrategy from transformers.utils import PaddingStrategy
@@ -251,13 +251,10 @@ class HFMistralTokenizer:
token_ids = [token_ids] token_ids = [token_ids]
if skip_special_tokens: if skip_special_tokens:
return self._mistral.instruct_tokenizer.tokenizer.decode( return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
)
return self._mistral.instruct_tokenizer.tokenizer.decode( # to_string returns a string with special tokens
token_ids, special_token_policy=SpecialTokenPolicy.KEEP return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
)
def _create_mistral_chat_completion_request( def _create_mistral_chat_completion_request(
self, conversation: list[dict], tools: list[dict] | None = None self, conversation: list[dict], tools: list[dict] | None = None

View File

@@ -146,7 +146,6 @@ class AxolotlInputConfig(
dpo_label_smoothing: float | None = None dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None dpo_norm_loss: bool | None = None
dpo_padding_free: bool | None = None dpo_padding_free: bool | None = None
dpo_generate_during_eval: bool | None = None
datasets: ( datasets: (
Annotated[ Annotated[

View File

@@ -48,8 +48,6 @@ class ChatTemplate(str, Enum):
llama4 = "llama4" llama4 = "llama4"
phi_3 = "phi_3" phi_3 = "phi_3"
phi_35 = "phi_35" phi_35 = "phi_35"
phi_35_vl = "phi_35_vl"
phi_4 = "phi_4"
deepseek_v2 = "deepseek_v2" deepseek_v2 = "deepseek_v2"
deepseek_v3 = "deepseek_v3" deepseek_v3 = "deepseek_v3"
jamba = "jamba" jamba = "jamba"

View File

@@ -4,14 +4,12 @@ shared pytest fixtures
import functools import functools
import importlib import importlib
import logging
import os import os
import shutil import shutil
import sys import sys
import tempfile import tempfile
import time import time
from pathlib import Path, PosixPath from pathlib import Path
from typing import Generator
import datasets import datasets
import pytest import pytest
@@ -26,8 +24,6 @@ from tests.hf_offline_utils import (
hf_offline_context, hf_offline_context,
) )
logging.getLogger("filelock").setLevel(logging.CRITICAL)
def retry_on_request_exceptions(max_retries=3, delay=1): def retry_on_request_exceptions(max_retries=3, delay=1):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -415,7 +411,7 @@ def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
@pytest.fixture @pytest.fixture
def temp_dir() -> Generator[str, None, None]: def temp_dir():
# Create a temporary directory # Create a temporary directory
_temp_dir = tempfile.mkdtemp() _temp_dir = tempfile.mkdtemp()
yield _temp_dir yield _temp_dir
@@ -423,11 +419,6 @@ def temp_dir() -> Generator[str, None, None]:
shutil.rmtree(_temp_dir) shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def unique_triton_cache_dir(temp_dir: str | PosixPath) -> None:
os.environ["TRITON_CACHE_DIR"] = str(temp_dir) + "/.triton/cache"
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches(): def cleanup_monkeypatches():
from transformers import Trainer from transformers import Trainer

View File

@@ -54,7 +54,6 @@ class TestSequenceParallelism:
"micro_batch_size": micro_batch_size, "micro_batch_size": micro_batch_size,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -54,7 +54,6 @@ class TestPackedFlex:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"gradient_checkpointing": True, "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -309,7 +309,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"warmup_steps": 10, "warmup_steps": 10,
"val_set_size": 0.0, "val_set_size": 0.0,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001, "learning_rate": 0.0001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -401,7 +400,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"warmup_steps": 10, "warmup_steps": 10,
"val_set_size": 0.0, "val_set_size": 0.0,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001, "learning_rate": 0.0001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -38,13 +38,12 @@ class TestMultiGPUEval:
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"], "lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.05, "val_set_size": 0.004,
"special_tokens": {"pad_token": "<|endoftext|>"}, "special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [ "datasets": [
{ {
"path": "teknium/GPT4-LLM-Cleaned", "path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca", "type": "alpaca",
"split": "train[:5%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
@@ -52,7 +51,6 @@ class TestMultiGPUEval:
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -109,13 +107,12 @@ class TestMultiGPUEval:
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"], "lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.01, "val_set_size": 0.0004,
"special_tokens": {"pad_token": "<|endoftext|>"}, "special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [ "datasets": [
{ {
"path": "teknium/GPT4-LLM-Cleaned", "path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca", "type": "alpaca",
"split": "train[:5%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
@@ -123,7 +120,6 @@ class TestMultiGPUEval:
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -64,7 +64,6 @@ class TestMultiGPUGemma3:
}, },
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001, "learning_rate": 0.0001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -62,7 +62,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -128,7 +127,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -202,7 +200,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0, "warmup_steps": 0,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
@@ -281,7 +278,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0, "warmup_steps": 0,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
@@ -344,7 +340,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -417,7 +412,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -497,7 +491,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"gradient_checkpointing": True, "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_8bit", "optimizer": "adamw_torch_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -580,7 +573,6 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -677,7 +669,6 @@ class TestMultiGPULlama:
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -752,7 +743,6 @@ class TestMultiGPULlama:
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -827,7 +817,6 @@ class TestMultiGPULlama:
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -46,7 +46,6 @@ class TestMultiGPUQwen2:
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -48,7 +48,6 @@ class TestMultiGPURay:
"micro_batch_size": 4, "micro_batch_size": 4,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -108,7 +107,6 @@ class TestMultiGPURay:
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",

View File

@@ -396,7 +396,7 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
def test_kernel_training_integration(temp_dir): def test_kernel_training_integration():
"""Test model loading with kernel patches enabled.""" """Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
@@ -426,14 +426,6 @@ def test_kernel_training_integration(temp_dir):
} }
) )
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Load model # Load model
model, _, _ = load_model_and_tokenizer(cfg=cfg) model, _, _ = load_model_and_tokenizer(cfg=cfg)
@@ -513,7 +505,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
assert found_patched_attn assert found_patched_attn
def test_kernel_training_integration_dropout_non_zero(temp_dir): def test_kernel_training_integration_dropout_non_zero():
"""Test model loading with dropout non-zero should not patch.""" """Test model loading with dropout non-zero should not patch."""
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
@@ -541,14 +533,6 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
} }
) )
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Get original attention class # Get original attention class
attention_cls = get_attention_cls_from_config(cfg) attention_cls = get_attention_cls_from_config(cfg)