diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml
index c854af9ab..91cbaf957 100644
--- a/.github/workflows/multi-gpu-e2e.yml
+++ b/.github/workflows/multi-gpu-e2e.yml
@@ -18,6 +18,13 @@ jobs:
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
+ - cuda: 121
+ cuda_version: 12.1.1
+ python_version: "3.11"
+ pytorch: 2.3.1
+ axolotl_extras:
+ num_gpus: 2
+ nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -39,6 +46,7 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
+ echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu
diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml
new file mode 100644
index 000000000..1440efe79
--- /dev/null
+++ b/.github/workflows/tests-nightly.yml
@@ -0,0 +1,116 @@
+# Nightly CI: lint, unit tests and Modal-hosted e2e tests. The test jobs install
+# transformers, peft, accelerate and bitsandbytes from their upstream main
+# branches instead of the versions pinned in requirements.txt.
+name: Tests Nightly against upstream main
+on:
+  workflow_dispatch:
+  schedule:
+    - cron: '0 0 * * *'  # Runs at 00:00 UTC every day
+
+jobs:
+  pre-commit:
+    name: pre-commit
+    runs-on: ubuntu-latest
+    steps:
+      # action versions kept in step with the docker-e2e-tests job below
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+          cache: 'pip'  # caching pip dependencies
+      - uses: pre-commit/action@v3.0.0
+        env:
+          SKIP: no-commit-to-branch
+
+  pytest:
+    name: PyTest
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python_version: ["3.10", "3.11"]
+    timeout-minutes: 20
+
+    steps:
+      - name: Check out repository code
+        uses: actions/checkout@v4
+
+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python_version }}
+          cache: 'pip'  # caching pip dependencies
+
+      # Rewrite the pinned requirement lines so pip installs upstream main.
+      - name: Update requirements.txt
+        run: |
+          sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
+          sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
+          sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
+          sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt
+
+      - name: Install dependencies
+        run: |
+          pip3 install --upgrade pip
+          pip3 install --upgrade packaging
+          pip3 install -U -e .
+          pip3 install -r requirements-tests.txt
+
+      - name: Run tests
+        run: |
+          pytest --ignore=tests/e2e/ tests/
+
+      - name: cleanup pip cache
+        run: |
+          cache_dir="$(pip cache dir)/http-v2"
+          # A missing cache directory must not fail an otherwise green run.
+          if [ -d "$cache_dir" ]; then
+            find "$cache_dir" -type f -mtime +14 -exec rm {} \;
+          fi
+
+  docker-e2e-tests:
+    if: github.repository_owner == 'axolotl-ai-cloud'
+    # this job needs to be run on self-hosted GPU runners...
+    runs-on: [self-hosted, modal]
+    timeout-minutes: 60
+    needs: [pre-commit, pytest]
+
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - cuda: 121
+            cuda_version: 12.1.1
+            python_version: "3.10"
+            pytorch: 2.3.1
+            num_gpus: 1
+            axolotl_extras: mamba-ssm
+            nightly_build: "true"
+          - cuda: 121
+            cuda_version: 12.1.1
+            python_version: "3.11"
+            pytorch: 2.3.1
+            num_gpus: 1
+            axolotl_extras: mamba-ssm
+            nightly_build: "true"
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.4.0
+            num_gpus: 1
+            # explicit empty string: no extras for this combination
+            axolotl_extras: ""
+            nightly_build: "true"
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      # Python here only runs the Modal client on the runner; the tested
+      # interpreter version is selected through BASE_TAG below.
+      - name: Install Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+      - name: Install Modal
+        run: |
+          python -m pip install --upgrade pip
+          pip install modal==0.63.64 jinja2
+      # No matrix entry defines axolotl_args, so AXOLOTL_ARGS is exported empty.
+      - name: Update env vars
+        run: |
+          echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
+          echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
+          echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
+          echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
+          echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
+          echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
+          echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
+      - name: Run tests job on Modal
+        run: |
+          modal run cicd.tests
diff --git a/README.md b/README.md
index a626635dc..8c70da015 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,9 @@
# Axolotl
+
+
+
+
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
Features:
@@ -22,39 +26,49 @@ Features:
## Table of Contents
-- [Introduction](#axolotl)
-- [Supported Features](#axolotl-supports)
-- [Quickstart](#quickstart-)
-- [Environment](#environment)
- - [Docker](#docker)
- - [Conda/Pip venv](#condapip-venv)
- - [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
- - [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- - [Windows](#windows)
- - [Mac](#mac)
- - [Google Colab](#google-colab)
- - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- - [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
-- [Dataset](#dataset)
-- [Config](#config)
- - [Train](#train)
- - [Inference](#inference-playground)
- - [Merge LORA to Base](#merge-lora-to-base)
- - [Special Tokens](#special-tokens)
- - [All Config Options](#all-config-options)
-- Advanced Topics
- - [Multipack](./docs/multipack.qmd)
- - [RLHF & DPO](./docs/rlhf.qmd)
- - [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)
- - [Unsloth](./docs/unsloth.qmd)
-- [Common Errors](#common-errors-)
- - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
-- [Debugging Axolotl](#debugging-axolotl)
-- [Need Help?](#need-help-)
-- [Badge](#badge-)
-- [Community Showcase](#community-showcase)
-- [Contributing](#contributing-)
-- [Sponsors](#sponsors-)
+- [Axolotl](#axolotl)
+ - [Table of Contents](#table-of-contents)
+ - [Axolotl supports](#axolotl-supports)
+ - [Quickstart ⚡](#quickstart-)
+ - [Usage](#usage)
+ - [Advanced Setup](#advanced-setup)
+ - [Environment](#environment)
+ - [Docker](#docker)
+ - [Conda/Pip venv](#condapip-venv)
+ - [Cloud GPU](#cloud-gpu)
+ - [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
+ - [LambdaLabs](#lambdalabs)
+ - [GCP](#gcp)
+ - [Windows](#windows)
+ - [Mac](#mac)
+ - [Google Colab](#google-colab)
+ - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
+ - [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
+ - [Dataset](#dataset)
+ - [Config](#config)
+ - [All Config Options](#all-config-options)
+ - [Train](#train)
+ - [Preprocess dataset](#preprocess-dataset)
+ - [Multi-GPU](#multi-gpu)
+ - [DeepSpeed](#deepspeed)
+ - [FSDP](#fsdp)
+ - [FSDP + QLoRA](#fsdp--qlora)
+ - [Weights \& Biases Logging](#weights--biases-logging)
+ - [Special Tokens](#special-tokens)
+ - [Inference Playground](#inference-playground)
+ - [Merge LORA to base](#merge-lora-to-base)
+ - [Common Errors 🧰](#common-errors-)
+ - [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
+ - [Debugging Axolotl](#debugging-axolotl)
+ - [Need help? 🙋](#need-help-)
+ - [Badge ❤🏷️](#badge-️)
+ - [Community Showcase](#community-showcase)
+ - [Contributing 🤝](#contributing-)
+ - [Sponsors 🤝❤](#sponsors-)
+ - [💎 Diamond Sponsors - Contact directly](#-diamond-sponsors---contact-directly)
+ - [🥇 Gold Sponsors - $5000/mo](#-gold-sponsors---5000mo)
+ - [🥈 Silver Sponsors - $1000/mo](#-silver-sponsors---1000mo)
+ - [🥉 Bronze Sponsors - $500/mo](#-bronze-sponsors---500mo)
|
@@ -96,6 +110,7 @@ Features:
| RWKV | β
| β | β | β | β | β | β |
| Qwen | β
| β
| β
| β | β | β | β |
| Gemma | β
| β
| β
| β | β | β
| β |
+| Jamba | β
| β
| β
| β | β | β
| β |
✅: supported
❌: not supported
diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja
index 3a7988366..c245fce3e 100644
--- a/cicd/Dockerfile.jinja
+++ b/cicd/Dockerfile.jinja
@@ -8,6 +8,7 @@ ENV BNB_CUDA_VERSION="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
+ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
@@ -23,6 +24,13 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
+RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
+ sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
+ sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
+ sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
+ sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt; \
+ fi
+
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
diff --git a/cicd/tests.py b/cicd/tests.py
index c21467637..9c2d830cb 100644
--- a/cicd/tests.py
+++ b/cicd/tests.py
@@ -28,6 +28,7 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
+ "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
}
dockerfile_contents = df_template.render(**df_args)
diff --git a/docs/unsloth.qmd b/docs/unsloth.qmd
index 390609fd3..90cb49baf 100644
--- a/docs/unsloth.qmd
+++ b/docs/unsloth.qmd
@@ -34,7 +34,7 @@ unsloth_lora_o: true
```
These options are composable and can be used with multi-gpu finetuning
-```
+```yaml
unsloth_cross_entropy_loss: true
unsloth_rms_norm: true
unsloth_rope: true
diff --git a/examples/jamba/README.md b/examples/jamba/README.md
index 54f5d1da9..4c9dc85a0 100644
--- a/examples/jamba/README.md
+++ b/examples/jamba/README.md
@@ -6,5 +6,5 @@
- ✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?)
- ✅ qlora single-gpu, ~51GiB VRAM
- ✅ multipack
-- β FSDP
+- ✅ FSDP
- β 8-bit LoRA
diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml
new file mode 100644
index 000000000..28316efd5
--- /dev/null
+++ b/examples/jamba/qlora_fsdp_large.yaml
@@ -0,0 +1,61 @@
+# Example: QLoRA fine-tune of AI21 Jamba 1.5 Large with FSDP full sharding.
+base_model: ai21labs/AI21-Jamba-1.5-Large
+tokenizer_type: AutoTokenizer
+
+load_in_4bit: true
+strict: false
+use_tensorboard: true
+
+# --- data ---
+datasets:
+  - path: cgato/SlimOrcaDedupCleaned
+    type: chat_template
+    chat_template: jamba
+    drop_system_message: true
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: jamba-large-fsdp-qlora-ft
+save_safetensors: true
+
+# --- sequence handling ---
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+# --- adapter ---
+adapter: qlora
+lora_r: 16
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - down_proj
+  - gate_proj
+  - in_proj
+  - k_proj
+  - o_proj
+  - out_proj
+  - q_proj
+  - up_proj
+  - v_proj
+  - x_proj
+lora_target_linear: false
+
+# --- optimisation ---
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 2
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 0.00001
+warmup_steps: 10
+weight_decay: 0.0
+
+train_on_inputs: false
+group_by_length: false
+
+# --- precision / memory ---
+bf16: true
+tf32: true
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: true
+flash_attention: true
+
+# --- logging / eval / checkpoints ---
+logging_steps: 1
+evals_per_epoch: 1
+saves_per_epoch: 1
+
+# --- FSDP ---
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: false
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  # both Jamba decoder layer classes are listed as wrap units
+  fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_sharding_strategy: FULL_SHARD
diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml
index 44f9c7e49..d61c72a37 100644
--- a/examples/qwen2/qlora-fsdp.yaml
+++ b/examples/qwen2/qlora-fsdp.yaml
@@ -72,4 +72,5 @@ fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
+ fsdp_sharding_strategy: FULL_SHARD
special_tokens:
diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml
index e501dcb8e..010a1608a 100644
--- a/examples/tiny-llama/pretrain.yml
+++ b/examples/tiny-llama/pretrain.yml
@@ -9,9 +9,9 @@ strict: false
max_steps: 200
pretraining_dataset:
- path: c4
- name: en
- type: pretrain
+ - path: allenai/c4
+ name: en
+ type: pretrain
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/model-out
diff --git a/requirements.txt b/requirements.txt
index dc74b916f..be0c4927e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,7 +21,7 @@ optimum==1.16.2
hf_transfer
colorama
numba
-numpy>=1.24.4
+numpy>=1.24.4,<=2.0.1
# qlora things
evaluate==0.4.1
scipy
diff --git a/setup.py b/setup.py
index 1d164e0a1..1b64fadae 100644
--- a/setup.py
+++ b/setup.py
@@ -80,7 +80,7 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
- "flash-attn==2.6.2",
+ "flash-attn==2.6.3",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py
new file mode 100644
index 000000000..25408fd57
--- /dev/null
+++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py
@@ -0,0 +1,204 @@
+"""
+This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint
+"""
+import json
+import logging
+import os
+import shutil
+from pathlib import Path
+from typing import Dict, Union
+
+import fire
+import torch
+import torch.distributed.checkpoint as dist_cp
+import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
+import transformers
+from accelerate.utils import (
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ WEIGHTS_INDEX_NAME,
+ WEIGHTS_NAME,
+ is_torch_version,
+)
+from dotenv import load_dotenv
+from huggingface_hub import split_torch_state_dict_into_shards
+from safetensors.torch import save_file as safe_save_file
+from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
+
+from axolotl.cli import load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs
+
+LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")
+
+
+class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
+ """
+ A custom planner to cast tensors to bfloat16 on the fly during loading.
+ """
+
+ def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
+ tensor.copy_(tensor.to(torch.bfloat16))
+
+
+def _distributed_checkpoint_to_merged_weights(
+ checkpoint_dir: Union[str, Path],
+ save_path: str,
+ safe_serialization: bool = False,
+ max_shard_size: str = "5GB",
+):
+ """
+ Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
+
+ Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
+ """
+
+ state_dict: Dict = {}
+ save_path_ = Path(save_path)
+ save_path_.mkdir(exist_ok=True)
+ dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access
+ state_dict,
+ storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
+ planner=BFloat16CastPlanner(), # pylint: disable=protected-access
+ no_dist=True,
+ )
+
+ # To handle if state is a dict like {model: {...}}
+ if len(state_dict.keys()) == 1:
+ state_dict = state_dict[list(state_dict)[0]]
+
+ # Ensure all tensors are in bfloat16
+ for key, value in state_dict.items():
+ if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
+ state_dict[key] = value.to(torch.bfloat16)
+
+ weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
+ ".safetensors", "{suffix}.safetensors"
+ )
+ state_dict_split = split_torch_state_dict_into_shards(
+ state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+ )
+ # Save index if sharded
+ index = None
+ if state_dict_split.is_sharded:
+ index = {
+ "metadata": state_dict_split.metadata,
+ "weight_map": state_dict_split.tensor_to_filename,
+ }
+
+ # Save the model
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
+
+ for shard_file, tensors in filename_to_tensors:
+ shard = {tensor: state_dict[tensor] for tensor in tensors}
+
+ if safe_serialization:
+ safe_save_file(
+ shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
+ )
+ else:
+ torch.save(shard, os.path.join(save_path_, shard_file))
+
+ if index is not None:
+ save_index_file = (
+ SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+ )
+ save_index_file = os.path.join(save_path_, save_index_file)
+ # Save the index as well
+ with open(save_index_file, "w", encoding="utf-8") as fout:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ fout.write(content)
+
+ return save_path_
+
+
+def merge_fsdp_weights(
+ checkpoint_dir: str,
+ output_path: str,
+ safe_serialization: bool = False,
+ remove_checkpoint_dir: bool = False,
+):
+ """
+ Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
+ `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
+ `safe_serialization` else `pytorch_model.bin`.
+
+ Note: this is a CPU-bound process.
+
+ Args:
+ checkpoint_dir (`str`):
+ The directory containing the FSDP checkpoints (can be either the model or optimizer).
+ output_path (`str`):
+ The path to save the merged checkpoint.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the merged weights with safetensors (recommended).
+ remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
+ Whether to remove the checkpoint directory after merging.
+ """
+ checkpoint_dir_ = Path(checkpoint_dir)
+ from accelerate.state import PartialState
+
+ if not is_torch_version(">=", "2.3.0"):
+ raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
+
+ # Verify that the checkpoint directory exists
+ if not checkpoint_dir_.exists():
+ model_path_exists = (checkpoint_dir_ / "pytorch_model_fsdp_0").exists()
+ optimizer_path_exists = (checkpoint_dir_ / "optimizer_0").exists()
+ err = f"Tried to load from {checkpoint_dir_} but couldn't find a valid metadata file."
+ if model_path_exists and optimizer_path_exists:
+ err += (
+ " However, potential model and optimizer checkpoint directories exist."
+ )
+ err += f"Please pass in either {checkpoint_dir_}/pytorch_model_fsdp_0 or {checkpoint_dir_}/optimizer_0"
+ err += "instead."
+ elif model_path_exists:
+ err += " However, a potential model checkpoint directory exists."
+ err += (
+ f"Please try passing in {checkpoint_dir_}/pytorch_model_fsdp_0 instead."
+ )
+ elif optimizer_path_exists:
+ err += " However, a potential optimizer checkpoint directory exists."
+ err += f"Please try passing in {checkpoint_dir_}/optimizer_0 instead."
+ raise ValueError(err)
+
+ # To setup `save` to work
+ state = PartialState()
+ if state.is_main_process:
+ LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
+ save_path = _distributed_checkpoint_to_merged_weights(
+ checkpoint_dir_, output_path, safe_serialization
+ )
+ LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
+ if remove_checkpoint_dir:
+ LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
+ shutil.rmtree(checkpoint_dir_)
+ state.wait_for_everyone()
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+ # pylint: disable=duplicate-code
+ print_axolotl_text_art()
+ parser = transformers.HfArgumentParser((TrainerCliArgs))
+ parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+ return_remaining_strings=True
+ )
+ parsed_cli_args.merge_lora = True
+
+ parsed_cfg = load_cfg(
+ config,
+ **kwargs,
+ )
+
+ fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
+ merge_fsdp_weights(
+ checkpoint_dir=str(fsdp_dir),
+ output_path=str(Path(parsed_cfg.output_dir) / "merged"),
+ safe_serialization=True,
+ )
+
+
+if __name__ == "__main__":
+ load_dotenv()
+ fire.Fire(do_cli)
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 4e8b36905..1a073ca04 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1846,6 +1846,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
)
if self.cfg.fsdp:
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
+ if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
+ ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 904352010..44fc4cb47 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -17,6 +17,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"qwen2_moe",
"falcon",
"phi",
+ "phi3",
"gemma",
"gemma2",
"gemmoe",
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 11c8aba7a..8240d8a28 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -361,7 +361,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
- "train_on_eos": ds_cfg.get("train_on_eos", "last"),
+ "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
strategy = ChatTemplateStrategy(
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 0ffa3e55f..13ff450f8 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -65,8 +65,10 @@ class AlpacaPrompter(Prompter):
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
elif self.prompt_style == PromptStyle.PHI.value:
self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
- self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
- self.system_format = "<|system|>{system}\n"
+ self.turn_no_input_format = (
+ "<|user|>\n{instruction}<|end|>\n<|assistant|>\n"
+ )
+ self.system_format = "<|system|>\n{system}<|end|>\n"
def _build_result(self, instruction, input_text, output):
# returns the full prompt from instruction and optional input
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index b8890d4f7..b21b0b269 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -12,6 +12,7 @@ import torch
import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger
+from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
@@ -194,9 +195,12 @@ def train(
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
+ state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled:
- trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
- LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
+ if cfg.fsdp_final_state_dict_type:
+ state_dict_type = cfg.fsdp_final_state_dict_type
+ trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
+ LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
@@ -208,7 +212,18 @@ def train(
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
- trainer.save_model(cfg.output_dir)
+ if (
+ state_dict_type == "SHARDED_STATE_DICT"
+ and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
+ ):
+ save_fsdp_model(
+ trainer.accelerator.state.fsdp_plugin,
+ trainer.accelerator,
+ trainer.model,
+ cfg.output_dir,
+ )
+ elif state_dict_type == "FULL_STATE_DICT":
+ trainer.save_model(cfg.output_dir)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index ad161a49c..55d70fd9e 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -23,6 +23,7 @@ _TEMPLATES = {
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<ο½Userο½>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<ο½Assistantο½>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<ο½Assistantο½>' }}{% endif %}",
+ "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n 
{% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." 
-}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n 
{{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set 
content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n',
}
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index aa5eea6af..89cd36784 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -40,6 +40,7 @@ class ChatTemplate(str, Enum):
llama3 = "llama3" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
+ jamba = "jamba" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
@@ -650,6 +651,9 @@ class AxolotlInputConfig(
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None
+ fsdp_final_state_dict_type: Optional[
+ Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
+ ] = None
val_set_size: Optional[float] = Field(default=0.0)
@@ -1186,6 +1190,20 @@ class AxolotlInputConfig(
)
return data
+ @model_validator(mode="before")
+ @classmethod
+ def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
+ if (
+ data.get("fsdp")
+ and data.get("save_safetensors")
+ and data.get("fsdp_config")
+ and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
+ ):
+ raise ValueError(
+ "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
+ )
+ return data
+
@model_validator(mode="before")
@classmethod
def check_causal_lm_evals(cls, data):
diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index e056c7f50..16f38218c 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -18,10 +18,10 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
- tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+ tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
) -> Dict[str, List]:
res = tokenizer(
- examples,
+ examples["text"],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index e2c4244f9..3c8feb9b4 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -547,7 +547,9 @@ def load_model(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
- if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
+ if cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
+ cfg.deepspeed or cfg.fsdp
+ ):
# for some reason, this causes the loss to be off by an order of magnitude
- # but deepspeed needs this still in bfloat16
+ # but deepspeed and FSDP need this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -592,16 +594,10 @@ def load_model(
"flash_attention_2"
)
else:
- if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
- model_kwargs["attn_implementation"] = "flash_attention_2"
- model_config._attn_implementation = ( # pylint: disable=protected-access
- "flash_attention_2"
- )
- else:
- model_kwargs["attn_implementation"] = "eager"
- model_config._attn_implementation = ( # pylint: disable=protected-access
- "eager"
- )
+ model_kwargs["attn_implementation"] = "flash_attention_2"
+ model_config._attn_implementation = ( # pylint: disable=protected-access
+ "flash_attention_2"
+ )
elif cfg.sdp_attention:
model_kwargs["attn_implementation"] = "sdpa"
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
@@ -1103,9 +1099,20 @@ def load_lora(model, cfg, inference=False, config_only=False):
def ensure_dtype(model, dtype=torch.bfloat16):
for name, module in model.named_modules():
+ weight_mismatch = False
+ bias_mismatch = False
try:
- if module.weight.dtype != dtype:
- print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
- module.to(dtype)
+ weight_mismatch = module.weight.dtype != dtype
except AttributeError:
pass
+ try:
+ bias_mismatch = module.bias.dtype != dtype
+ except AttributeError:
+ pass
+
+ if weight_mismatch:
+ print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
+ if bias_mismatch:
+ print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
+ if weight_mismatch or bias_mismatch:
+ module.to(dtype)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 26796f2e5..99c10c655 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -399,12 +399,15 @@ def setup_torch_compile_env(cfg):
def setup_deepspeed_env(cfg, stage=None):
+ from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
+ HfTrainerDeepSpeedConfig(cfg.deepspeed)
def setup_fsdp_envs(cfg):
diff --git a/tests/e2e/multigpu/test_qwen2.py b/tests/e2e/multigpu/test_qwen2.py
new file mode 100644
index 000000000..2513be69e
--- /dev/null
+++ b/tests/e2e/multigpu/test_qwen2.py
@@ -0,0 +1,98 @@
+"""
+E2E tests for multigpu qwen2
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMultiGPUQwen2(unittest.TestCase):
+ """
+ Test case for Qwen2 models using QLoRA with FSDP
+ """
+
+ @with_temp_dir
+ def test_qlora_fsdp_dpo(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "Qwen/Qwen2-1.5B",
+ "load_in_4bit": True,
+ "rl": "dpo",
+ "chat_template": "chatml",
+ "sequence_len": 2048,
+ "adapter": "qlora",
+ "lora_r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "val_set_size": 0.05,
+ "datasets": [
+ {
+ "path": "Intel/orca_dpo_pairs",
+ "split": "train",
+ "type": "chatml.intel",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 100,
+ "warmup_steps": 20,
+ "micro_batch_size": 4,
+ "gradient_accumulation_steps": 2,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "bf16": "auto",
+ "tf32": True,
+ "gradient_checkpointing": True,
+ "gradient_checkpointing_kwargs": {
+ "use_reentrant": False,
+ },
+ "fsdp": [
+ "full_shard",
+ "auto_wrap",
+ ],
+ "fsdp_config": {
+ "fsdp_limit_all_gathers": True,
+ "fsdp_offload_params": False,
+ "fsdp_sync_module_states": True,
+ "fsdp_use_orig_params": False,
+ "fsdp_cpu_ram_efficient_loading": False,
+ "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
+ "fsdp_state_dict_type": "FULL_STATE_DICT",
+ "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+ "fsdp_sharding_strategy": "FULL_SHARD",
+ },
+ }
+ )
+
+ # write cfg to yaml file
+ Path(temp_dir).mkdir(parents=True, exist_ok=True)
+ with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+ fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+ execute_subprocess_async(
+ [
+ "accelerate",
+ "launch",
+ "--num-processes",
+ "2",
+ "-m",
+ "axolotl.cli.train",
+ str(Path(temp_dir) / "config.yaml"),
+ ]
+ )
diff --git a/tests/test_data.py b/tests/test_data.py
index 16af089a0..9d7f5a041 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
"hello, hello",
]
}
- result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
+ result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
self.assertEqual(len(result["input_ids"]), 3)
diff --git a/tests/test_prompters.py b/tests/test_prompters.py
index 6c5b8f27c..3d61398e0 100644
--- a/tests/test_prompters.py
+++ b/tests/test_prompters.py
@@ -42,6 +42,19 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "USER:" not in res
assert "ASSISTANT:" not in res
+ def test_prompt_style_w_phi(self):
+ prompter = AlpacaPrompter(prompt_style=PromptStyle.PHI.value)
+ res = next(prompter.build_prompt("tell me a joke about the following"))
+ assert (
+ """<|system|>
+Below is an instruction that describes a task. Write a response that appropriately completes the request.<|end|>
+<|user|>
+tell me a joke about the following<|end|>
+<|assistant|>
+"""
+ == res
+ )
+
def test_prompt_style_w_chat(self):
prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
|