Compare commits
4 Commits
20240216-u
...
feat/space
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39ad38a1fb | ||
|
|
ddb60883f5 | ||
|
|
a5724ef08d | ||
|
|
190930b5df |
2
.github/workflows/base.yml
vendored
2
.github/workflows/base.yml
vendored
@@ -7,7 +7,7 @@ jobs:
|
||||
build-base:
|
||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: axolotl-gpu-runner
|
||||
runs-on: self-hosted
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
||||
18
.github/workflows/main.yml
vendored
18
.github/workflows/main.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -34,7 +35,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -55,16 +56,27 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
load: true
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
- name: Unit Tests
|
||||
run: |
|
||||
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
- name: Push to Docker Hub
|
||||
if: github.event_name != 'pull_request'
|
||||
run: |
|
||||
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
if [ -n "$latest_tag" ]; then
|
||||
docker push "$latest_tag"
|
||||
fi
|
||||
|
||||
build-axolotl-runpod:
|
||||
needs: build-axolotl
|
||||
@@ -94,7 +106,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -32,9 +32,6 @@ ignore_missing_imports = True
|
||||
[mypy-bitsandbytes]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-requests]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-datasets]
|
||||
ignore_missing_imports = True
|
||||
|
||||
|
||||
29
README.md
29
README.md
@@ -25,8 +25,8 @@ Features:
|
||||
- [Installation](#installation)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
|
||||
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
|
||||
- [Cloud GPU](#cloud-gpu) - Runpod, Latitude
|
||||
- [LambdaLabs](#lambdalabs)
|
||||
- [Windows](#windows)
|
||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||
- [Dataset](#dataset)
|
||||
@@ -121,10 +121,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
# gradio
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out" --gradio
|
||||
|
||||
# remote yaml files - the yaml config can be hosted on a public URL
|
||||
# Note: the yaml config must directly link to the **raw** yaml
|
||||
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
|
||||
```
|
||||
|
||||
## Installation
|
||||
@@ -186,13 +182,9 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
||||
|
||||
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
||||
|
||||
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
#### Bare Metal Cloud GPU
|
||||
|
||||
##### LambdaLabs
|
||||
|
||||
#### LambdaLabs
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
@@ -472,12 +464,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
dataset:
|
||||
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
||||
...
|
||||
|
||||
# Loading Data From a Public URL
|
||||
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
|
||||
dataset:
|
||||
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
|
||||
ds_type: json # this is the default, see other options below.
|
||||
```
|
||||
|
||||
- loading
|
||||
@@ -990,9 +976,6 @@ Run
|
||||
accelerate launch -m axolotl.cli.train your_config.yml
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
|
||||
|
||||
#### Preprocess dataset
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning.
|
||||
@@ -1217,12 +1200,6 @@ pre-commit install
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
|
||||
|
||||
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
|
||||
</a>
|
||||
|
||||
## Sponsors 🤝❤
|
||||
|
||||
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
eval_sample_packing: false
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: false
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -1,4 +1,3 @@
|
||||
pre-commit
|
||||
black
|
||||
mypy
|
||||
types-requests
|
||||
|
||||
@@ -9,7 +9,6 @@ deepspeed>=0.13.1
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
datasets>=2.15.0
|
||||
flash-attn==2.3.3
|
||||
sentencepiece
|
||||
|
||||
24
setup.py
24
setup.py
@@ -1,7 +1,5 @@
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
import platform
|
||||
import re
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
@@ -28,25 +26,11 @@ def parse_requirements():
|
||||
_install_requires.append(line)
|
||||
|
||||
try:
|
||||
if "Darwin" in platform.system():
|
||||
torch_version = version("torch")
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
if torch_version.startswith("2.1."):
|
||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||
else:
|
||||
torch_version = version("torch")
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
if version_match:
|
||||
major, minor, patch = version_match.groups()
|
||||
major, minor = int(major), int(minor)
|
||||
patch = (
|
||||
int(patch) if patch is not None else 0
|
||||
) # Default patch to 0 if not present
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 1):
|
||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||
_install_requires.append("xformers>=0.0.23")
|
||||
_install_requires.append("xformers>=0.0.23")
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import gradio as gr
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
@@ -63,52 +59,6 @@ def print_axolotl_text_art(suffix=None):
|
||||
print(ascii_art)
|
||||
|
||||
|
||||
def check_remote_config(config: Union[str, Path]):
|
||||
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
|
||||
if not (isinstance(config, str) and config.startswith("https://")):
|
||||
return config # Return the original value if it's not a valid URL
|
||||
|
||||
filename = os.path.basename(urlparse(config).path)
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
|
||||
try:
|
||||
response = requests.get(config, timeout=30)
|
||||
response.raise_for_status() # Check for HTTP errors
|
||||
|
||||
content = response.content
|
||||
try:
|
||||
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
|
||||
json.loads(content)
|
||||
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
|
||||
LOG.warning(
|
||||
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# If it's not valid JSON, verify it's valid YAML
|
||||
try:
|
||||
yaml.safe_load(content)
|
||||
except yaml.YAMLError as err:
|
||||
raise ValueError(
|
||||
f"Failed to parse the content at {config} as YAML: {err}"
|
||||
) from err
|
||||
|
||||
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
|
||||
output_path = Path(temp_dir) / filename
|
||||
with open(output_path, "wb") as file:
|
||||
file.write(content)
|
||||
LOG.info(
|
||||
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
|
||||
)
|
||||
return output_path
|
||||
|
||||
except requests.RequestException as err:
|
||||
# This catches all requests-related exceptions including HTTPError
|
||||
raise RuntimeError(f"Failed to download {config}: {err}") from err
|
||||
except Exception as err:
|
||||
# Catch-all for any other exceptions
|
||||
raise err
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
print("Give me an instruction (Ctrl + D to submit): ")
|
||||
instruction = ""
|
||||
@@ -320,10 +270,9 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
|
||||
return not any(el in list2 for el in list1)
|
||||
|
||||
|
||||
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
config = check_remote_config(config)
|
||||
def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
if Path(config).is_dir():
|
||||
config = choose_config(Path(config))
|
||||
config = choose_config(config)
|
||||
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
|
||||
@@ -3,7 +3,6 @@ CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
@@ -24,7 +23,7 @@ from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
@@ -3,7 +3,6 @@ CLI to shard a trained model into 10GiB chunks
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
@@ -26,7 +25,7 @@ def shard(
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
@@ -3,7 +3,7 @@ CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple
|
||||
|
||||
import fire
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
@@ -25,7 +25,7 @@ from axolotl.train import train
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser((TrainerCliArgs))
|
||||
|
||||
@@ -28,7 +28,6 @@ from transformers import (
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
@@ -49,7 +48,7 @@ from axolotl.utils.collators import (
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.schedulers import (
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup, JaggedLRRestartScheduler,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -129,19 +128,7 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many anneal steps to take before reset for ReLoRA"},
|
||||
)
|
||||
jagged_restart_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for jagged restarts"},
|
||||
)
|
||||
jagged_restarts_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for jagged restarts"},
|
||||
)
|
||||
jagged_restarts_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many anneal steps to take before reset for jagged restarts"},
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
@@ -238,7 +225,7 @@ class AxolotlTrainer(Trainer):
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
super().create_scheduler(num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||
@@ -246,21 +233,6 @@ class AxolotlTrainer(Trainer):
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||
|
||||
if self.args.jagged_restart_steps:
|
||||
warmup_steps = (
|
||||
self.args.jagged_restarts_warmup_steps or 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.jagged_restarts_anneal_steps or 1
|
||||
)
|
||||
self.lr_scheduler = JaggedLRRestartScheduler(
|
||||
optimizer,
|
||||
self.lr_scheduler,
|
||||
self.args.jagged_restart_steps,
|
||||
warmup_steps,
|
||||
anneal_steps,
|
||||
)
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
@@ -900,8 +872,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
if self.cfg.save_only_model:
|
||||
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
|
||||
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler
|
||||
if self.cfg.lr_scheduler
|
||||
@@ -1024,7 +994,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
]
|
||||
]
|
||||
if use_batch_sampler_collator:
|
||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif (
|
||||
self.cfg.model_config_type in ["llama"]
|
||||
|
||||
12
src/axolotl/monkeypatch/falcon/__init__.py
Normal file
12
src/axolotl/monkeypatch/falcon/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Patches to support multipack for falcon
|
||||
"""
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def replace_falcon_attn_with_multipack_flash_attn():
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
@@ -2,6 +2,9 @@
|
||||
Patches to support multipack for mixtral
|
||||
"""
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def patch_mixtral_moe_forward_zero3() -> None:
|
||||
@@ -48,3 +51,11 @@ def patch_mixtral_moe_forward_zero3() -> None:
|
||||
|
||||
MixtralBLockSparseTop2MLP.forward = mlp_forward
|
||||
MixtralSparseMoeBlock.forward = moe_forward
|
||||
|
||||
|
||||
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if for_zero3:
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
"""multipack patching for v2 of sample packing"""
|
||||
|
||||
import transformers
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type):
|
||||
if model_type == "mixtral":
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
elif model_type == "qwen2":
|
||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "falcon":
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "phi":
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
12
src/axolotl/monkeypatch/phi/__init__.py
Normal file
12
src/axolotl/monkeypatch/phi/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Patches to support multipack for phi2
|
||||
"""
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def replace_phi_attn_with_multipack_flash_attn():
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
12
src/axolotl/monkeypatch/qwen2/__init__.py
Normal file
12
src/axolotl/monkeypatch/qwen2/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Patches to support multipack for qwen2
|
||||
"""
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def replace_qwen2_attn_with_multipack_flash_attn():
|
||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
@@ -186,8 +186,8 @@ def mask_2d_to_4d(
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1, device=mask.device).to(dtype),
|
||||
torch.tensor(0, device=mask.device).to(dtype),
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import Prompter
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
|
||||
|
||||
class ChatTemplatePrompter(Prompter):
|
||||
def __init__(self, tokenizer, chat_template=None, max_length=2048):
|
||||
self.tokenizer = tokenizer
|
||||
self.chat_template = chat_template
|
||||
self.max_length = max_length
|
||||
|
||||
def build_prompt(self, conversation, add_generation_prompt=False):
|
||||
return self.tokenizer.apply_chat_template(
|
||||
conversation, truncation=True, max_length=self.max_length,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
chat_template=self.chat_template,
|
||||
)
|
||||
|
||||
|
||||
class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
|
||||
input_ids = self.prompter.build_prompt(turns)
|
||||
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(prompt_ids)
|
||||
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
|
||||
else:
|
||||
labels = input_ids
|
||||
|
||||
|
||||
tokenized_prompt = {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": [1] * len(input_ids)
|
||||
}
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap roles - allow for assistant turn
|
||||
role_map = {"human": "user", "user": "user", "assistant": "assistant", "gpt": "assistant"}
|
||||
turns = [
|
||||
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_templates(ds_cfg["conversation"]),
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
return strategy
|
||||
@@ -208,10 +208,7 @@ def train(
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
try:
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
except AttributeError:
|
||||
pass
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
elif cfg.hub_model_id:
|
||||
# defensively push to the hub to ensure the model card is updated
|
||||
trainer.push_to_hub()
|
||||
|
||||
@@ -47,12 +47,6 @@ def gpu_memory_usage_all(device=0):
|
||||
return usage, reserved - usage, max(0, smi - reserved)
|
||||
|
||||
|
||||
def mps_memory_usage_all():
|
||||
usage = torch.mps.current_allocated_memory() / 1024.0**3
|
||||
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
|
||||
return usage, reserved - usage, 0
|
||||
|
||||
|
||||
@check_cuda_device(0.0)
|
||||
def gpu_memory_usage_smi(device=0):
|
||||
if isinstance(device, torch.device):
|
||||
@@ -69,10 +63,7 @@ def gpu_memory_usage_smi(device=0):
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
if torch.backends.mps.is_available():
|
||||
usage, cache, misc = mps_memory_usage_all()
|
||||
else:
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
extras = []
|
||||
if cache > 0:
|
||||
extras.append(f"+{cache:.03f}GB cache")
|
||||
|
||||
@@ -62,7 +62,7 @@ class EvalFirstStepCallback(
|
||||
):
|
||||
if (
|
||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and (args.eval_steps < 1.0 or args.eval_steps > 1)
|
||||
and args.eval_steps < 1.0
|
||||
and state.global_step == 1
|
||||
):
|
||||
control.should_evaluate = True
|
||||
|
||||
@@ -336,16 +336,6 @@ def load_tokenized_prepared_datasets(
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
elif config_dataset.path.startswith("https://"):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
|
||||
@@ -29,10 +29,6 @@ from transformers import ( # noqa: F401
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
patch_for_multipack,
|
||||
)
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
@@ -303,15 +299,8 @@ def load_model(
|
||||
shifted-sparse attention does not currently support sample packing."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
patch_for_multipack(cfg.model_config_type)
|
||||
elif cfg.is_llama_derived_model:
|
||||
# Modify all llama derived models in one block
|
||||
|
||||
# Modify all llama derived models in one block
|
||||
if cfg.is_llama_derived_model:
|
||||
if cfg.flash_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
@@ -365,6 +354,43 @@ def load_model(
|
||||
LOG.info("patching mistral with flash attention")
|
||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
|
||||
if (
|
||||
cfg.model_config_type == "mixtral"
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
from axolotl.monkeypatch.mixtral import (
|
||||
replace_mixtral_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching mixtral with flash attention")
|
||||
mixtral_patch_kwargs = {}
|
||||
if is_deepspeed_zero3_enabled():
|
||||
mixtral_patch_kwargs["for_zero3"] = True
|
||||
replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
|
||||
|
||||
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.falcon import (
|
||||
replace_falcon_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching falcon with flash attention")
|
||||
replace_falcon_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
|
||||
|
||||
LOG.info("patching phi with flash attention")
|
||||
replace_phi_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.qwen2 import (
|
||||
replace_qwen2_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching qwen2 with flash attention")
|
||||
replace_qwen2_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||
|
||||
@@ -374,7 +400,7 @@ def load_model(
|
||||
model_kwargs: Dict[str, Any] = {}
|
||||
|
||||
if cfg.model_kwargs:
|
||||
for key, val in cfg.model_kwargs.items():
|
||||
for key, val in model_kwargs.items():
|
||||
model_kwargs[key] = val
|
||||
|
||||
max_memory = cfg.max_memory
|
||||
@@ -409,10 +435,6 @@ def load_model(
|
||||
|
||||
model_kwargs["device_map"] = device_map
|
||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
model_kwargs["device_map"] = "mps:0"
|
||||
|
||||
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
|
||||
# if cfg.rl:
|
||||
# if torch.cuda.device_count() > 1:
|
||||
@@ -479,7 +501,7 @@ def load_model(
|
||||
"flash_attention_2"
|
||||
)
|
||||
else:
|
||||
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||
if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flash_attention_2"
|
||||
@@ -655,7 +677,7 @@ def load_model(
|
||||
):
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
|
||||
if hasattr(model, "device") and model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
# make sure these are fp32 per Ramesh et al. (2021)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Module for custom LRScheduler class"""
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
@@ -141,48 +140,3 @@ def get_cosine_schedule_with_min_lr(
|
||||
min_lr_ratio=min_lr_ratio,
|
||||
)
|
||||
return LambdaLR(optimizer, lr_lambda)
|
||||
|
||||
|
||||
class JaggedLRRestartScheduler(LRScheduler):
|
||||
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
inner_schedule: LRScheduler,
|
||||
jagged_restarts_steps: int,
|
||||
jagged_restarts_warmup_steps: int,
|
||||
jagged_restarts_anneal_steps: int = 1,
|
||||
min_lr_scale: float = 0.001,
|
||||
) -> None:
|
||||
self.inner_schedule = inner_schedule
|
||||
self.restarts_steps = jagged_restarts_steps
|
||||
self.warmup_steps = jagged_restarts_warmup_steps
|
||||
self.anneal_steps = jagged_restarts_anneal_steps
|
||||
self.min_lr_scale = min_lr_scale
|
||||
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
||||
|
||||
def get_lr(self) -> float:
|
||||
self.inner_schedule.last_epoch = self.last_epoch
|
||||
|
||||
original = self.inner_schedule.get_lr()
|
||||
step = self.last_epoch
|
||||
|
||||
if step < self.restarts_steps:
|
||||
scale = 1
|
||||
else:
|
||||
per_relora_progress = step % self.restarts_steps
|
||||
if per_relora_progress < self.warmup_steps:
|
||||
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
|
||||
elif per_relora_progress > (self.restarts_steps - self.anneal_steps):
|
||||
cycle_t = min(
|
||||
1.0,
|
||||
(self.restarts_steps - per_relora_progress) / self.anneal_steps,
|
||||
)
|
||||
else:
|
||||
cycle_t = 1
|
||||
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
||||
|
||||
if isinstance(original, Sequence):
|
||||
return [lr * scale for lr in original]
|
||||
return original * scale
|
||||
|
||||
98
ui/main.py
Normal file
98
ui/main.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
This module is used to launch Axolotl with user defined configurations.
|
||||
"""
|
||||
|
||||
import gradio as gr
|
||||
import yaml
|
||||
|
||||
|
||||
def config(
|
||||
base_model,
|
||||
dataset,
|
||||
dataset_type,
|
||||
learn_rate,
|
||||
gradient_accumulation_steps,
|
||||
micro_batch_size,
|
||||
seq_length,
|
||||
num_epochs,
|
||||
output_dir,
|
||||
val_size,
|
||||
):
|
||||
"""
|
||||
This function generates a configuration dictionary and saves it as a yaml file.
|
||||
"""
|
||||
config_dict = {
|
||||
"base_model": base_model,
|
||||
"datasets": [{"path": dataset, "type": dataset_type}],
|
||||
"learning_rate": learn_rate,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"micro_batch_size": micro_batch_size,
|
||||
"sequence_len": seq_length,
|
||||
"num_epochs": num_epochs,
|
||||
"output_dir": output_dir,
|
||||
"val_set_size": val_size,
|
||||
}
|
||||
with open("config.yml", "w", encoding="utf-8") as file:
|
||||
yaml.dump(config_dict, file)
|
||||
print(config_dict)
|
||||
return yaml.dump(config_dict)
|
||||
|
||||
|
||||
with gr.Blocks(title="Axolotl Launcher") as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# Axolotl Launcher
|
||||
Fill out the required fields below to create a training run.
|
||||
"""
|
||||
)
|
||||
with gr.Row():
|
||||
base_model_name = gr.Textbox(
|
||||
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", label="Base model"
|
||||
)
|
||||
|
||||
mode = gr.Radio(
|
||||
choices=["Full finetune", "QLoRA", "LoRA"],
|
||||
label="Training mode",
|
||||
info="FFT = 16 bit, Qlora = 4 bit, Lora = 8 bit",
|
||||
)
|
||||
with gr.Row():
|
||||
dataset_path = gr.Textbox("mhenrichsen/alpaca_2k_test", label="Dataset")
|
||||
dataset_type_name = gr.Dropdown(
|
||||
choices=["alpaca", "sharegpt"], label="Dataset type", value="alpaca"
|
||||
)
|
||||
with gr.Accordion("Hyperparameters", open=False):
|
||||
gr.Markdown("Choose hyperparameters")
|
||||
with gr.Row():
|
||||
learning_rate = gr.Number(0.000001, label="Learning rate")
|
||||
gradient_accumulation_steps_count = gr.Number(
|
||||
1, label="Gradient accumulation steps"
|
||||
)
|
||||
val_set_size_count = gr.Number(0, label="Validation size")
|
||||
|
||||
with gr.Row():
|
||||
micro_batch_size_count = gr.Number(1, label="Micro batch size")
|
||||
sequence_length = gr.Number(1024, label="Sequence length")
|
||||
num_epochs_count = gr.Number(1, label="Epochs")
|
||||
|
||||
output_dir_path = gr.Textbox("./model-out", label="Output directory")
|
||||
|
||||
create_config = gr.Button("Create config")
|
||||
output = gr.TextArea(label="Generated config")
|
||||
create_config.click(
|
||||
config,
|
||||
inputs=[
|
||||
base_model_name,
|
||||
dataset_path,
|
||||
dataset_type_name,
|
||||
learning_rate,
|
||||
gradient_accumulation_steps_count,
|
||||
micro_batch_size_count,
|
||||
sequence_length,
|
||||
num_epochs_count,
|
||||
output_dir_path,
|
||||
val_set_size_count,
|
||||
],
|
||||
outputs=output,
|
||||
)
|
||||
|
||||
demo.launch(debug=True, server_name="0.0.0.0", server_port=7860)
|
||||
Reference in New Issue
Block a user