diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index c854af9ab..91cbaf957 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -18,6 +18,13 @@ jobs: pytorch: 2.3.1 axolotl_extras: num_gpus: 2 + - cuda: 121 + cuda_version: 12.1.1 + python_version: "3.11" + pytorch: 2.3.1 + axolotl_extras: + num_gpus: 2 + nightly_build: "true" runs-on: [self-hosted, modal] timeout-minutes: 120 steps: @@ -39,6 +46,7 @@ jobs: echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV + echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV - name: Run tests job on Modal run: | modal run cicd.multigpu diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml new file mode 100644 index 000000000..1440efe79 --- /dev/null +++ b/.github/workflows/tests-nightly.yml @@ -0,0 +1,116 @@ +name: Tests Nightly against upstream main +on: + workflow_dispatch: + schedule: + - cron: '0 0 * * *' # Runs at 00:00 UTC every day + +jobs: + pre-commit: + name: pre-commit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: 'pip' # caching pip dependencies + - uses: pre-commit/action@v3.0.0 + env: + SKIP: no-commit-to-branch + + pytest: + name: PyTest + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python_version: ["3.10", "3.11"] + timeout-minutes: 20 + + steps: + - name: Check out repository code + uses: actions/checkout@v3 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version }} + cache: 'pip' # caching pip dependencies + + - name: Update requirements.txt + run: | + sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt + sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt + sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt + sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt + + - name: Install dependencies + run: | + pip3 install --upgrade pip + pip3 install --upgrade packaging + pip3 install -U -e . + pip3 install -r requirements-tests.txt + + - name: Run tests + run: | + pytest --ignore=tests/e2e/ tests/ + + - name: cleanup pip cache + run: | + find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; + + docker-e2e-tests: + if: github.repository_owner == 'axolotl-ai-cloud' + # this job needs to be run on self-hosted GPU runners... + runs-on: [self-hosted, modal] + timeout-minutes: 60 + needs: [pre-commit, pytest] + + strategy: + fail-fast: false + matrix: + include: + - cuda: 121 + cuda_version: 12.1.1 + python_version: "3.10" + pytorch: 2.3.1 + num_gpus: 1 + axolotl_extras: mamba-ssm + nightly_build: "true" + - cuda: 121 + cuda_version: 12.1.1 + python_version: "3.11" + pytorch: 2.3.1 + num_gpus: 1 + axolotl_extras: mamba-ssm + nightly_build: "true" + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.4.0 + num_gpus: 1 + axolotl_extras: + nightly_build: "true" + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install Modal + run: | + python -m pip install --upgrade pip + pip install modal==0.63.64 jinja2 + - name: Update env vars + run: | + echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV + echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV + echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV + echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV + echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV + echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV + echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV + - name: Run tests job on Modal + run: | + modal run cicd.tests diff --git a/README.md b/README.md index a626635dc..8c70da015 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # Axolotl +![tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg) +![tests-nightly](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg) +![multigpu-semi-weekly tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg) + Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures. Features: @@ -22,39 +26,49 @@ Features: ## Table of Contents -- [Introduction](#axolotl) -- [Supported Features](#axolotl-supports) -- [Quickstart](#quickstart-) -- [Environment](#environment) - - [Docker](#docker) - - [Conda/Pip venv](#condapip-venv) - - [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod - - [Bare Metal Cloud GPU](#bare-metal-cloud-gpu) - - [Windows](#windows) - - [Mac](#mac) - - [Google Colab](#google-colab) - - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot) - - [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack) -- [Dataset](#dataset) -- [Config](#config) - - [Train](#train) - - [Inference](#inference-playground) - - [Merge LORA to Base](#merge-lora-to-base) - - [Special Tokens](#special-tokens) - - [All Config Options](#all-config-options) -- Advanced Topics - - [Multipack](./docs/multipack.qmd) - - [RLHF & DPO](./docs/rlhf.qmd) - - [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd) - - [Unsloth](./docs/unsloth.qmd) -- [Common Errors](#common-errors-) - - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training) -- [Debugging Axolotl](#debugging-axolotl) -- [Need Help?](#need-help-) -- [Badge](#badge-) -- [Community Showcase](#community-showcase) -- [Contributing](#contributing-) -- [Sponsors](#sponsors-) +- [Axolotl](#axolotl) + - [Table of Contents](#table-of-contents) + - [Axolotl supports](#axolotl-supports) + - [Quickstart ⚑](#quickstart-) + - [Usage](#usage) + - [Advanced Setup](#advanced-setup) + - [Environment](#environment) + - [Docker](#docker) + - [Conda/Pip venv](#condapip-venv) + - [Cloud GPU](#cloud-gpu) + - [Bare Metal Cloud GPU](#bare-metal-cloud-gpu) + - [LambdaLabs](#lambdalabs) + - [GCP](#gcp) + - [Windows](#windows) + - [Mac](#mac) + - [Google Colab](#google-colab) + - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot) + - [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack) + - [Dataset](#dataset) + - [Config](#config) + - [All Config Options](#all-config-options) + - [Train](#train) + - [Preprocess dataset](#preprocess-dataset) + - [Multi-GPU](#multi-gpu) + - [DeepSpeed](#deepspeed) + - [FSDP](#fsdp) + - [FSDP + QLoRA](#fsdp--qlora) + - [Weights \& Biases Logging](#weights--biases-logging) + - [Special Tokens](#special-tokens) + - [Inference Playground](#inference-playground) + - [Merge LORA to base](#merge-lora-to-base) + - [Common Errors 🧰](#common-errors-) + - [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training) + - [Debugging Axolotl](#debugging-axolotl) + - [Need help? πŸ™‹](#need-help-) + - [Badge ❀🏷️](#badge-️) + - [Community Showcase](#community-showcase) + - [Contributing 🀝](#contributing-) + - [Sponsors 🀝❀](#sponsors-) + - [πŸ’Ž Diamond Sponsors - Contact directly](#-diamond-sponsors---contact-directly) + - [πŸ₯‡ Gold Sponsors - $5000/mo](#-gold-sponsors---5000mo) + - [πŸ₯ˆ Silver Sponsors - $1000/mo](#-silver-sponsors---1000mo) + - [πŸ₯‰ Bronze Sponsors - $500/mo](#-bronze-sponsors---500mo) @@ -96,6 +110,7 @@ Features: | RWKV | βœ… | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | | Qwen | βœ… | βœ… | βœ… | ❓ | ❓ | ❓ | ❓ | | Gemma | βœ… | βœ… | βœ… | ❓ | ❓ | βœ… | ❓ | +| Jamba | βœ… | βœ… | βœ… | ❓ | ❓ | βœ… | ❓ | βœ…: supported ❌: not supported diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 3a7988366..c245fce3e 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -8,6 +8,7 @@ ENV BNB_CUDA_VERSION="{{ CUDA }}" ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}" ENV GITHUB_REF="{{ GITHUB_REF }}" ENV GITHUB_SHA="{{ GITHUB_SHA }}" +ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}" RUN apt-get update && \ apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev @@ -23,6 +24,13 @@ RUN git fetch origin +$GITHUB_REF && \ # If AXOLOTL_EXTRAS is set, append it in brackets RUN pip install causal_conv1d +RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ + sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \ + sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \ + sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \ + sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt; \ + fi + RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ diff --git a/cicd/tests.py b/cicd/tests.py index c21467637..9c2d830cb 100644 --- a/cicd/tests.py +++ b/cicd/tests.py @@ -28,6 +28,7 @@ df_args = { "CUDA": os.environ.get("CUDA", "121"), "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), + "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""), } dockerfile_contents = df_template.render(**df_args) diff --git a/docs/unsloth.qmd b/docs/unsloth.qmd index 390609fd3..90cb49baf 100644 --- a/docs/unsloth.qmd +++ b/docs/unsloth.qmd @@ -34,7 +34,7 @@ unsloth_lora_o: true ``` These options are composable and can be used with multi-gpu finetuning -``` +```yaml unsloth_cross_entropy_loss: true unsloth_rms_norm: true unsloth_rope: true diff --git a/examples/jamba/README.md b/examples/jamba/README.md index 54f5d1da9..4c9dc85a0 100644 --- a/examples/jamba/README.md +++ b/examples/jamba/README.md @@ -6,5 +6,5 @@ - βœ… qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?) - βœ… qlora single-gpu, ~51GiB VRAM - βœ… multipack -- ❓ FSDP +- βœ… FSDP - ❓ 8-bit LoRA diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml new file mode 100644 index 000000000..28316efd5 --- /dev/null +++ b/examples/jamba/qlora_fsdp_large.yaml @@ -0,0 +1,61 @@ +base_model: ai21labs/AI21-Jamba-1.5-Large +tokenizer_type: AutoTokenizer + +load_in_4bit: true +strict: false +use_tensorboard: true +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + chat_template: jamba + drop_system_message: true +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: jamba-large-fsdp-qlora-ft +save_safetensors: true +adapter: qlora +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 16 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: [down_proj,gate_proj,in_proj,k_proj,o_proj,out_proj,q_proj,up_proj,v_proj,x_proj] +lora_target_linear: false + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 2 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 0.00001 + +train_on_inputs: false +group_by_length: false +bf16: true +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true +logging_steps: 1 +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: false + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml index 44f9c7e49..d61c72a37 100644 --- a/examples/qwen2/qlora-fsdp.yaml +++ b/examples/qwen2/qlora-fsdp.yaml @@ -72,4 +72,5 @@ fsdp_config: fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD special_tokens: diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml index e501dcb8e..010a1608a 100644 --- a/examples/tiny-llama/pretrain.yml +++ b/examples/tiny-llama/pretrain.yml @@ -9,9 +9,9 @@ strict: false max_steps: 200 pretraining_dataset: - path: c4 - name: en - type: pretrain + - path: allenai/c4 + name: en + type: pretrain dataset_prepared_path: val_set_size: 0.0 output_dir: ./outputs/model-out diff --git a/requirements.txt b/requirements.txt index dc74b916f..be0c4927e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ optimum==1.16.2 hf_transfer colorama numba -numpy>=1.24.4 +numpy>=1.24.4,<=2.0.1 # qlora things evaluate==0.4.1 scipy diff --git a/setup.py b/setup.py index 1d164e0a1..1b64fadae 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ setup( dependency_links=dependency_links, extras_require={ "flash-attn": [ - "flash-attn==2.6.2", + "flash-attn==2.6.3", ], "fused-dense-lib": [ "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib", diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py new file mode 100644 index 000000000..25408fd57 --- /dev/null +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -0,0 +1,204 @@ +""" +This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint +""" +import json +import logging +import os +import shutil +from pathlib import Path +from typing import Dict, Union + +import fire +import torch +import torch.distributed.checkpoint as dist_cp +import torch.distributed.checkpoint.format_utils as dist_cp_format_utils +import transformers +from accelerate.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + is_torch_version, +) +from dotenv import load_dotenv +from huggingface_hub import split_torch_state_dict_into_shards +from safetensors.torch import save_file as safe_save_file +from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner + +from axolotl.cli import load_cfg, print_axolotl_text_art +from axolotl.common.cli import TrainerCliArgs + +LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights") + + +class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): + """ + A custom planner to cast tensors to bfloat16 on the fly during loading. + """ + + def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument + tensor.copy_(tensor.to(torch.bfloat16)) + + +def _distributed_checkpoint_to_merged_weights( + checkpoint_dir: Union[str, Path], + save_path: str, + safe_serialization: bool = False, + max_shard_size: str = "5GB", +): + """ + Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save` + + Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`. + """ + + state_dict: Dict = {} + save_path_ = Path(save_path) + save_path_.mkdir(exist_ok=True) + dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access + state_dict, + storage_reader=dist_cp.FileSystemReader(checkpoint_dir), + planner=BFloat16CastPlanner(), # pylint: disable=protected-access + no_dist=True, + ) + + # To handle if state is a dict like {model: {...}} + if len(state_dict.keys()) == 1: + state_dict = state_dict[list(state_dict)[0]] + + # Ensure all tensors are in bfloat16 + for key, value in state_dict.items(): + if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16: + state_dict[key] = value.to(torch.bfloat16) + + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + # Save index if sharded + index = None + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + + # Save the model + filename_to_tensors = state_dict_split.filename_to_tensors.items() + + for shard_file, tensors in filename_to_tensors: + shard = {tensor: state_dict[tensor] for tensor in tensors} + + if safe_serialization: + safe_save_file( + shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"} + ) + else: + torch.save(shard, os.path.join(save_path_, shard_file)) + + if index is not None: + save_index_file = ( + SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + ) + save_index_file = os.path.join(save_path_, save_index_file) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as fout: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + fout.write(content) + + return save_path_ + + +def merge_fsdp_weights( + checkpoint_dir: str, + output_path: str, + safe_serialization: bool = False, + remove_checkpoint_dir: bool = False, +): + """ + Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if + `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if + `safe_serialization` else `pytorch_model.bin`. + + Note: this is a CPU-bound process. + + Args: + checkpoint_dir (`str`): + The directory containing the FSDP checkpoints (can be either the model or optimizer). + output_path (`str`): + The path to save the merged checkpoint. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the merged weights with safetensors (recommended). + remove_checkpoint_dir (`bool`, *optional*, defaults to `False`): + Whether to remove the checkpoint directory after merging. + """ + checkpoint_dir_ = Path(checkpoint_dir) + from accelerate.state import PartialState + + if not is_torch_version(">=", "2.3.0"): + raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`") + + # Verify that the checkpoint directory exists + if not checkpoint_dir_.exists(): + model_path_exists = (checkpoint_dir_ / "pytorch_model_fsdp_0").exists() + optimizer_path_exists = (checkpoint_dir_ / "optimizer_0").exists() + err = f"Tried to load from {checkpoint_dir_} but couldn't find a valid metadata file." + if model_path_exists and optimizer_path_exists: + err += ( + " However, potential model and optimizer checkpoint directories exist." + ) + err += f"Please pass in either {checkpoint_dir_}/pytorch_model_fsdp_0 or {checkpoint_dir_}/optimizer_0" + err += "instead." + elif model_path_exists: + err += " However, a potential model checkpoint directory exists." + err += ( + f"Please try passing in {checkpoint_dir_}/pytorch_model_fsdp_0 instead." + ) + elif optimizer_path_exists: + err += " However, a potential optimizer checkpoint directory exists." + err += f"Please try passing in {checkpoint_dir_}/optimizer_0 instead." + raise ValueError(err) + + # To setup `save` to work + state = PartialState() + if state.is_main_process: + LOG.info(f"Merging FSDP weights from {checkpoint_dir_}") + save_path = _distributed_checkpoint_to_merged_weights( + checkpoint_dir_, output_path, safe_serialization + ) + LOG.info(f"Successfully merged FSDP weights and saved to {save_path}") + if remove_checkpoint_dir: + LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}") + shutil.rmtree(checkpoint_dir_) + state.wait_for_everyone() + + +def do_cli(config: Path = Path("examples/"), **kwargs): + # pylint: disable=duplicate-code + print_axolotl_text_art() + parser = transformers.HfArgumentParser((TrainerCliArgs)) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + parsed_cli_args.merge_lora = True + + parsed_cfg = load_cfg( + config, + **kwargs, + ) + + fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0" + merge_fsdp_weights( + checkpoint_dir=str(fsdp_dir), + output_path=str(Path(parsed_cfg.output_dir) / "merged"), + safe_serialization=True, + ) + + +if __name__ == "__main__": + load_dotenv() + fire.Fire(do_cli) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4e8b36905..1a073ca04 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1846,6 +1846,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): ) if self.cfg.fsdp: ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype) + if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model: + ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype) dpo_trainer = self.hook_post_create_trainer(dpo_trainer) for callback in self.get_post_trainer_create_callbacks(dpo_trainer): diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 904352010..44fc4cb47 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -17,6 +17,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "qwen2_moe", "falcon", "phi", + "phi3", "gemma", "gemma2", "gemmoe", diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 11c8aba7a..8240d8a28 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -361,7 +361,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): "train_on_inputs": cfg.train_on_inputs, "sequence_len": cfg.sequence_len, "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]), - "train_on_eos": ds_cfg.get("train_on_eos", "last"), + "train_on_eos": ds_cfg.get("train_on_eos", "turn"), } strategy = ChatTemplateStrategy( diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 0ffa3e55f..13ff450f8 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -65,8 +65,10 @@ class AlpacaPrompter(Prompter): self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" elif self.prompt_style == PromptStyle.PHI.value: self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>" - self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>" - self.system_format = "<|system|>{system}\n" + self.turn_no_input_format = ( + "<|user|>\n{instruction}<|end|>\n<|assistant|>\n" + ) + self.system_format = "<|system|>\n{system}<|end|>\n" def _build_result(self, instruction, input_text, output): # returns the full prompt from instruction and optional input diff --git a/src/axolotl/train.py b/src/axolotl/train.py index b8890d4f7..b21b0b269 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -12,6 +12,7 @@ import torch import transformers.modelcard from accelerate import Accelerator from accelerate.logging import get_logger +from accelerate.utils import save_fsdp_model from datasets import Dataset from peft import PeftModel from pkg_resources import get_distribution # type: ignore @@ -194,9 +195,12 @@ def train( if hasattr(module, "_post_training"): module._post_training(model, name) # pylint: disable=protected-access + state_dict_type = "FULL_STATE_DICT" if trainer.is_fsdp_enabled: - trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") - LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.") + if cfg.fsdp_final_state_dict_type: + state_dict_type = cfg.fsdp_final_state_dict_type + trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type) + LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.") if cfg.relora_steps: if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): @@ -208,7 +212,18 @@ def train( # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file if cfg.fsdp: - trainer.save_model(cfg.output_dir) + if ( + state_dict_type == "SHARDED_STATE_DICT" + and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT" + ): + save_fsdp_model( + trainer.accelerator.state.fsdp_plugin, + trainer.accelerator, + trainer.model, + cfg.output_dir, + ) + elif state_dict_type == "FULL_STATE_DICT": + trainer.save_model(cfg.output_dir) elif cfg.deepspeed and is_deepspeed_zero3_enabled(): # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading trainer.accelerator.wait_for_everyone() diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index ad161a49c..55d70fd9e 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -23,6 +23,7 @@ _TEMPLATES = { "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", + "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', } diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index aa5eea6af..89cd36784 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -40,6 +40,7 @@ class ChatTemplate(str, Enum): llama3 = "llama3" # pylint: disable=invalid-name phi_3 = "phi_3" # pylint: disable=invalid-name deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name + jamba = "jamba" # pylint: disable=invalid-name tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name @@ -650,6 +651,9 @@ class AxolotlInputConfig( deepspeed: Optional[Union[str, Dict[str, Any]]] = None fsdp: Optional[List[str]] = None fsdp_config: Optional[Dict[str, Any]] = None + fsdp_final_state_dict_type: Optional[ + Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] + ] = None val_set_size: Optional[float] = Field(default=0.0) @@ -1186,6 +1190,20 @@ class AxolotlInputConfig( ) return data + @model_validator(mode="before") + @classmethod + def check_fsdp_sharded_state_dict_w_safetensors(cls, data): + if ( + data.get("fsdp") + and data.get("save_safetensors") + and data.get("fsdp_config") + and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" + ): + raise ValueError( + "FSDP SHARDED_STATE_DICT not compatible with save_safetensors" + ) + return data + @model_validator(mode="before") @classmethod def check_causal_lm_evals(cls, data): diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index e056c7f50..16f38218c 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -18,10 +18,10 @@ LOG = logging.getLogger("axolotl") def encode_pretraining( - tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str] + tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List] ) -> Dict[str, List]: res = tokenizer( - examples, + examples["text"], truncation=True, max_length=max_tokens - 2, add_special_tokens=True, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index e2c4244f9..3c8feb9b4 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -547,7 +547,9 @@ def load_model( "bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_storage": torch.bfloat16, } - if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed: + if cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( + cfg.deepspeed or cfg.fsdp + ): # for some reason, this causes the loss to be off by an order of magnitude # but deepspeed needs this still in bfloat16 bnb_config["bnb_4bit_quant_storage"] = torch.float32 @@ -592,16 +594,10 @@ def load_model( "flash_attention_2" ) else: - if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES: - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - else: - model_kwargs["attn_implementation"] = "eager" - model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" - ) + model_kwargs["attn_implementation"] = "flash_attention_2" + model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) elif cfg.sdp_attention: model_kwargs["attn_implementation"] = "sdpa" model_config._attn_implementation = "sdpa" # pylint: disable=protected-access @@ -1103,9 +1099,20 @@ def load_lora(model, cfg, inference=False, config_only=False): def ensure_dtype(model, dtype=torch.bfloat16): for name, module in model.named_modules(): + weight_mismatch = False + bias_mismatch = False try: - if module.weight.dtype != dtype: - print(f"Converting module {name}: {module.weight.dtype} -> {dtype}") - module.to(dtype) + weight_mismatch = module.weight.dtype != dtype except AttributeError: pass + try: + bias_mismatch = module.bias.dtype != dtype + except AttributeError: + pass + + if weight_mismatch: + print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}") + if bias_mismatch: + print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}") + if weight_mismatch or bias_mismatch: + module.to(dtype) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 26796f2e5..99c10c655 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -399,12 +399,15 @@ def setup_torch_compile_env(cfg): def setup_deepspeed_env(cfg, stage=None): + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed if stage: os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage) if stage == 3: os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true" + HfTrainerDeepSpeedConfig(cfg.deepspeed) def setup_fsdp_envs(cfg): diff --git a/tests/e2e/multigpu/test_qwen2.py b/tests/e2e/multigpu/test_qwen2.py new file mode 100644 index 000000000..2513be69e --- /dev/null +++ b/tests/e2e/multigpu/test_qwen2.py @@ -0,0 +1,98 @@ +""" +E2E tests for multigpu qwen2 +""" + +import logging +import os +import unittest +from pathlib import Path + +import yaml +from accelerate.test_utils import execute_subprocess_async + +from axolotl.utils.dict import DictDefault + +from ..utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +os.environ["WANDB_DISABLED"] = "true" + + +class TestMultiGPUQwen2(unittest.TestCase): + """ + Test case for Llama models using LoRA + """ + + @with_temp_dir + def test_qlora_fsdp_dpo(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2-1.5B", + "load_in_4bit": True, + "rl": "dpo", + "chat_template": "chatml", + "sequence_len": 2048, + "adapter": "qlora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.05, + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "split": "train", + "type": "chatml.intel", + }, + ], + "num_epochs": 1, + "max_steps": 100, + "warmup_steps": 20, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "flash_attention": True, + "bf16": "auto", + "tf32": True, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": { + "use_reentrant": False, + }, + "fsdp": [ + "full_shard", + "auto_wrap", + ], + "fsdp_config": { + "fsdp_limit_all_gathers": True, + "fsdp_offload_params": False, + "fsdp_sync_module_states": True, + "fsdp_use_orig_params": False, + "fsdp_cpu_ram_efficient_loading": False, + "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_sharding_strategy": "FULL_SHARD", + }, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "accelerate", + "launch", + "--num-processes", + "2", + "-m", + "axolotl.cli.train", + str(Path(temp_dir) / "config.yaml"), + ] + ) diff --git a/tests/test_data.py b/tests/test_data.py index 16af089a0..9d7f5a041 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase): "hello, hello", ] } - result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"]) + result = encode_pretraining(self.tokenizer, self.max_tokens, examples) self.assertEqual(len(result["input_ids"]), 3) diff --git a/tests/test_prompters.py b/tests/test_prompters.py index 6c5b8f27c..3d61398e0 100644 --- a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -42,6 +42,19 @@ class AlpacaPrompterTest(unittest.TestCase): assert "USER:" not in res assert "ASSISTANT:" not in res + def test_prompt_style_w_phi(self): + prompter = AlpacaPrompter(prompt_style=PromptStyle.PHI.value) + res = next(prompter.build_prompt("tell me a joke about the following")) + assert ( + """<|system|> +Below is an instruction that describes a task. Write a response that appropriately completes the request.<|end|> +<|user|> +tell me a joke about the following<|end|> +<|assistant|> +""" + == res + ) + def test_prompt_style_w_chat(self): prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value) res = next(