Compare commits
12 Commits
feat/space
...
streaming-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e08df47584 | ||
|
|
fac2d98c26 | ||
|
|
ea00dd0852 | ||
|
|
b2a4cb4396 | ||
|
|
aaf54dc730 | ||
|
|
9bca7db133 | ||
|
|
91cf4ee72c | ||
|
|
1daecd161e | ||
|
|
4a654b331e | ||
|
|
5698943263 | ||
|
|
411293bdca | ||
|
|
73f1bdaa15 |
2
.github/workflows/base.yml
vendored
2
.github/workflows/base.yml
vendored
@@ -7,7 +7,7 @@ jobs:
|
||||
build-base:
|
||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: self-hosted
|
||||
runs-on: axolotl-gpu-runner
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
||||
18
.github/workflows/main.yml
vendored
18
.github/workflows/main.yml
vendored
@@ -9,7 +9,6 @@ on:
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -35,7 +34,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -56,27 +55,16 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
load: true
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
- name: Unit Tests
|
||||
run: |
|
||||
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
- name: Push to Docker Hub
|
||||
if: github.event_name != 'pull_request'
|
||||
run: |
|
||||
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
if [ -n "$latest_tag" ]; then
|
||||
docker push "$latest_tag"
|
||||
fi
|
||||
|
||||
build-axolotl-runpod:
|
||||
needs: build-axolotl
|
||||
@@ -106,7 +94,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -32,6 +32,9 @@ ignore_missing_imports = True
|
||||
[mypy-bitsandbytes]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-requests]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-datasets]
|
||||
ignore_missing_imports = True
|
||||
|
||||
|
||||
29
README.md
29
README.md
@@ -25,8 +25,8 @@ Features:
|
||||
- [Installation](#installation)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [Cloud GPU](#cloud-gpu) - Runpod, Latitude
|
||||
- [LambdaLabs](#lambdalabs)
|
||||
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
|
||||
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
|
||||
- [Windows](#windows)
|
||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||
- [Dataset](#dataset)
|
||||
@@ -121,6 +121,10 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
# gradio
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out" --gradio
|
||||
|
||||
# remote yaml files - the yaml config can be hosted on a public URL
|
||||
# Note: the yaml config must directly link to the **raw** yaml
|
||||
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
|
||||
```
|
||||
|
||||
## Installation
|
||||
@@ -182,9 +186,13 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
||||
|
||||
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
||||
|
||||
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
#### LambdaLabs
|
||||
#### Bare Metal Cloud GPU
|
||||
|
||||
##### LambdaLabs
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
@@ -464,6 +472,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
dataset:
|
||||
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
||||
...
|
||||
|
||||
# Loading Data From a Public URL
|
||||
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
|
||||
dataset:
|
||||
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
|
||||
ds_type: json # this is the default, see other options below.
|
||||
```
|
||||
|
||||
- loading
|
||||
@@ -976,6 +990,9 @@ Run
|
||||
accelerate launch -m axolotl.cli.train your_config.yml
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
|
||||
|
||||
#### Preprocess dataset
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning.
|
||||
@@ -1200,6 +1217,12 @@ pre-commit install
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
|
||||
|
||||
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
|
||||
</a>
|
||||
|
||||
## Sponsors 🤝❤
|
||||
|
||||
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
|
||||
|
||||
65
examples/tiny-llama/lora-mps.yml
Normal file
65
examples/tiny-llama/lora-mps.yml
Normal file
@@ -0,0 +1,65 @@
|
||||
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
eval_sample_packing: false
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: false
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -10,9 +10,9 @@ strict: false
|
||||
|
||||
max_steps: 200
|
||||
pretraining_dataset:
|
||||
path: c4
|
||||
name: en
|
||||
type: pretrain
|
||||
- path: c4
|
||||
name: en
|
||||
type: pretrain
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./model-out
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pre-commit
|
||||
black
|
||||
mypy
|
||||
types-requests
|
||||
|
||||
@@ -9,6 +9,7 @@ deepspeed>=0.13.1
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
datasets>=2.15.0
|
||||
flash-attn==2.3.3
|
||||
sentencepiece
|
||||
|
||||
24
setup.py
24
setup.py
@@ -1,5 +1,7 @@
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
import platform
|
||||
import re
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
@@ -26,11 +28,25 @@ def parse_requirements():
|
||||
_install_requires.append(line)
|
||||
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
if torch_version.startswith("2.1."):
|
||||
if "Darwin" in platform.system():
|
||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||
_install_requires.append("xformers>=0.0.23")
|
||||
else:
|
||||
torch_version = version("torch")
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
if version_match:
|
||||
major, minor, patch = version_match.groups()
|
||||
major, minor = int(major), int(minor)
|
||||
patch = (
|
||||
int(patch) if patch is not None else 0
|
||||
) # Default patch to 0 if not present
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 1):
|
||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||
_install_requires.append("xformers>=0.0.23")
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import gradio as gr
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
@@ -59,6 +63,52 @@ def print_axolotl_text_art(suffix=None):
|
||||
print(ascii_art)
|
||||
|
||||
|
||||
def check_remote_config(config: Union[str, Path]):
|
||||
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
|
||||
if not (isinstance(config, str) and config.startswith("https://")):
|
||||
return config # Return the original value if it's not a valid URL
|
||||
|
||||
filename = os.path.basename(urlparse(config).path)
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
|
||||
try:
|
||||
response = requests.get(config, timeout=30)
|
||||
response.raise_for_status() # Check for HTTP errors
|
||||
|
||||
content = response.content
|
||||
try:
|
||||
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
|
||||
json.loads(content)
|
||||
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
|
||||
LOG.warning(
|
||||
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# If it's not valid JSON, verify it's valid YAML
|
||||
try:
|
||||
yaml.safe_load(content)
|
||||
except yaml.YAMLError as err:
|
||||
raise ValueError(
|
||||
f"Failed to parse the content at {config} as YAML: {err}"
|
||||
) from err
|
||||
|
||||
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
|
||||
output_path = Path(temp_dir) / filename
|
||||
with open(output_path, "wb") as file:
|
||||
file.write(content)
|
||||
LOG.info(
|
||||
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
|
||||
)
|
||||
return output_path
|
||||
|
||||
except requests.RequestException as err:
|
||||
# This catches all requests-related exceptions including HTTPError
|
||||
raise RuntimeError(f"Failed to download {config}: {err}") from err
|
||||
except Exception as err:
|
||||
# Catch-all for any other exceptions
|
||||
raise err
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
print("Give me an instruction (Ctrl + D to submit): ")
|
||||
instruction = ""
|
||||
@@ -270,9 +320,10 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
|
||||
return not any(el in list2 for el in list1)
|
||||
|
||||
|
||||
def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
config = check_remote_config(config)
|
||||
if Path(config).is_dir():
|
||||
config = choose_config(config)
|
||||
config = choose_config(Path(config))
|
||||
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
|
||||
@@ -3,6 +3,7 @@ CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
@@ -23,7 +24,7 @@ from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
@@ -3,6 +3,7 @@ CLI to shard a trained model into 10GiB chunks
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
@@ -25,7 +26,7 @@ def shard(
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
@@ -3,7 +3,7 @@ CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Union
|
||||
|
||||
import fire
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
@@ -25,7 +25,7 @@ from axolotl.train import train
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser((TrainerCliArgs))
|
||||
|
||||
@@ -28,6 +28,7 @@ from transformers import (
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
@@ -994,7 +995,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
]
|
||||
]
|
||||
if use_batch_sampler_collator:
|
||||
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif (
|
||||
self.cfg.model_config_type in ["llama"]
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
"""
|
||||
Patches to support multipack for falcon
|
||||
"""
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def replace_falcon_attn_with_multipack_flash_attn():
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
@@ -2,9 +2,6 @@
|
||||
Patches to support multipack for mixtral
|
||||
"""
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def patch_mixtral_moe_forward_zero3() -> None:
|
||||
@@ -51,11 +48,3 @@ def patch_mixtral_moe_forward_zero3() -> None:
|
||||
|
||||
MixtralBLockSparseTop2MLP.forward = mlp_forward
|
||||
MixtralSparseMoeBlock.forward = moe_forward
|
||||
|
||||
|
||||
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if for_zero3:
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
|
||||
30
src/axolotl/monkeypatch/multipack.py
Normal file
30
src/axolotl/monkeypatch/multipack.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""multipack patching for v2 of sample packing"""
|
||||
|
||||
import transformers
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type):
|
||||
if model_type == "mixtral":
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
elif model_type == "qwen2":
|
||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "falcon":
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "phi":
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
@@ -1,12 +0,0 @@
|
||||
"""
|
||||
Patches to support multipack for phi2
|
||||
"""
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def replace_phi_attn_with_multipack_flash_attn():
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
@@ -1,12 +0,0 @@
|
||||
"""
|
||||
Patches to support multipack for qwen2
|
||||
"""
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def replace_qwen2_attn_with_multipack_flash_attn():
|
||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
@@ -186,8 +186,8 @@ def mask_2d_to_4d(
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
torch.tensor(1, device=mask.device).to(dtype),
|
||||
torch.tensor(0, device=mask.device).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
|
||||
0
src/axolotl/plugins/oaaic/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/data/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/data/__init__.py
Normal file
28
src/axolotl/plugins/oaaic/data/streaming_sql.py
Normal file
28
src/axolotl/plugins/oaaic/data/streaming_sql.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import os
|
||||
from typing import Callable, Generator, Tuple
|
||||
|
||||
import psycopg
|
||||
import psycopg.conninfo
|
||||
|
||||
|
||||
def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
|
||||
pgsql_conn = os.environ.get("PGSQL_CONN", None)
|
||||
if not pgsql_conn:
|
||||
raise ValueError("missing PGSQL_CONN environment variable")
|
||||
conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)
|
||||
|
||||
def data_generator() -> Generator[Tuple, None, None]:
|
||||
with psycopg.connect(**conn_dict) as conn:
|
||||
with conn.cursor() as cur:
|
||||
page_size = 10
|
||||
last_id = None
|
||||
while True:
|
||||
if last_id:
|
||||
where_clause = f" WHERE {id_field} > {last_id}"
|
||||
cur.execute(
|
||||
f"SELECT * FROM {pgsql_table}{where_clause} ORDER BY {id_field} ASC LIMIT {page_size}"
|
||||
)
|
||||
for row in cur.fetchall():
|
||||
yield row[id_field], dict(row)
|
||||
|
||||
return data_generator
|
||||
@@ -208,7 +208,10 @@ def train(
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
try:
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
except AttributeError:
|
||||
pass
|
||||
elif cfg.hub_model_id:
|
||||
# defensively push to the hub to ensure the model card is updated
|
||||
trainer.push_to_hub()
|
||||
|
||||
@@ -47,6 +47,12 @@ def gpu_memory_usage_all(device=0):
|
||||
return usage, reserved - usage, max(0, smi - reserved)
|
||||
|
||||
|
||||
def mps_memory_usage_all():
|
||||
usage = torch.mps.current_allocated_memory() / 1024.0**3
|
||||
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
|
||||
return usage, reserved - usage, 0
|
||||
|
||||
|
||||
@check_cuda_device(0.0)
|
||||
def gpu_memory_usage_smi(device=0):
|
||||
if isinstance(device, torch.device):
|
||||
@@ -63,7 +69,10 @@ def gpu_memory_usage_smi(device=0):
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
if torch.backends.mps.is_available():
|
||||
usage, cache, misc = mps_memory_usage_all()
|
||||
else:
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
extras = []
|
||||
if cache > 0:
|
||||
extras.append(f"+{cache:.03f}GB cache")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Module containing data utilities"""
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
@@ -11,10 +12,12 @@ import yaml
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
IterableDataset,
|
||||
concatenate_datasets,
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
from datasets.iterable_dataset import ExamplesIterable
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import HFValidationError
|
||||
from torch.utils.data import RandomSampler
|
||||
@@ -64,6 +67,25 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
||||
|
||||
|
||||
def get_streaming_dataset(ds_cfg):
|
||||
path = ds_cfg["path"]
|
||||
func = None
|
||||
try:
|
||||
load_fn = path.split(".")[-1]
|
||||
module_name = ".".join(load_fn.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{module_name}", "axolotl")
|
||||
func = getattr(mod, load_fn)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if func:
|
||||
data_producer = func(**ds_cfg)
|
||||
return IterableDataset(ExamplesIterable(data_producer, {}))
|
||||
else:
|
||||
split = ds_cfg["split"] or "train"
|
||||
return load_dataset(path, streaming=True, split=split, name=ds_cfg["name"])
|
||||
|
||||
|
||||
def prepare_dataset(cfg, tokenizer):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
@@ -80,14 +102,6 @@ def prepare_dataset(cfg, tokenizer):
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
else:
|
||||
path = cfg.pretraining_dataset
|
||||
name = None
|
||||
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
||||
cfg.pretraining_dataset[0], dict
|
||||
):
|
||||
path = cfg.pretraining_dataset[0]["path"]
|
||||
name = cfg.pretraining_dataset[0]["name"]
|
||||
|
||||
ds_wrapper_partial = functools.partial(
|
||||
get_dataset_wrapper,
|
||||
cfg.pretraining_dataset[0],
|
||||
@@ -97,7 +111,7 @@ def prepare_dataset(cfg, tokenizer):
|
||||
)
|
||||
|
||||
train_dataset = wrap_pretraining_dataset(
|
||||
load_dataset(path, streaming=True, split="train", name=name),
|
||||
get_streaming_dataset(cfg.pretraining_dataset[0]),
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_partial,
|
||||
@@ -336,6 +350,16 @@ def load_tokenized_prepared_datasets(
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
elif config_dataset.path.startswith("https://"):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
|
||||
@@ -29,6 +29,10 @@ from transformers import ( # noqa: F401
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
patch_for_multipack,
|
||||
)
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
@@ -299,8 +303,15 @@ def load_model(
|
||||
shifted-sparse attention does not currently support sample packing."
|
||||
)
|
||||
|
||||
# Modify all llama derived models in one block
|
||||
if cfg.is_llama_derived_model:
|
||||
if (
|
||||
cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
patch_for_multipack(cfg.model_config_type)
|
||||
elif cfg.is_llama_derived_model:
|
||||
# Modify all llama derived models in one block
|
||||
|
||||
if cfg.flash_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
@@ -354,43 +365,6 @@ def load_model(
|
||||
LOG.info("patching mistral with flash attention")
|
||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
|
||||
if (
|
||||
cfg.model_config_type == "mixtral"
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
from axolotl.monkeypatch.mixtral import (
|
||||
replace_mixtral_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching mixtral with flash attention")
|
||||
mixtral_patch_kwargs = {}
|
||||
if is_deepspeed_zero3_enabled():
|
||||
mixtral_patch_kwargs["for_zero3"] = True
|
||||
replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
|
||||
|
||||
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.falcon import (
|
||||
replace_falcon_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching falcon with flash attention")
|
||||
replace_falcon_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
|
||||
|
||||
LOG.info("patching phi with flash attention")
|
||||
replace_phi_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.qwen2 import (
|
||||
replace_qwen2_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching qwen2 with flash attention")
|
||||
replace_qwen2_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||
|
||||
@@ -400,7 +374,7 @@ def load_model(
|
||||
model_kwargs: Dict[str, Any] = {}
|
||||
|
||||
if cfg.model_kwargs:
|
||||
for key, val in model_kwargs.items():
|
||||
for key, val in cfg.model_kwargs.items():
|
||||
model_kwargs[key] = val
|
||||
|
||||
max_memory = cfg.max_memory
|
||||
@@ -435,6 +409,10 @@ def load_model(
|
||||
|
||||
model_kwargs["device_map"] = device_map
|
||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
model_kwargs["device_map"] = "mps:0"
|
||||
|
||||
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
|
||||
# if cfg.rl:
|
||||
# if torch.cuda.device_count() > 1:
|
||||
@@ -501,7 +479,7 @@ def load_model(
|
||||
"flash_attention_2"
|
||||
)
|
||||
else:
|
||||
if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
||||
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flash_attention_2"
|
||||
@@ -677,7 +655,7 @@ def load_model(
|
||||
):
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if hasattr(model, "device") and model.device.type == "cuda":
|
||||
if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
# make sure these are fp32 per Ramesh et al. (2021)
|
||||
|
||||
98
ui/main.py
98
ui/main.py
@@ -1,98 +0,0 @@
|
||||
"""
|
||||
This module is used to launch Axolotl with user defined configurations.
|
||||
"""
|
||||
|
||||
import gradio as gr
|
||||
import yaml
|
||||
|
||||
|
||||
def config(
|
||||
base_model,
|
||||
dataset,
|
||||
dataset_type,
|
||||
learn_rate,
|
||||
gradient_accumulation_steps,
|
||||
micro_batch_size,
|
||||
seq_length,
|
||||
num_epochs,
|
||||
output_dir,
|
||||
val_size,
|
||||
):
|
||||
"""
|
||||
This function generates a configuration dictionary and saves it as a yaml file.
|
||||
"""
|
||||
config_dict = {
|
||||
"base_model": base_model,
|
||||
"datasets": [{"path": dataset, "type": dataset_type}],
|
||||
"learning_rate": learn_rate,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"micro_batch_size": micro_batch_size,
|
||||
"sequence_len": seq_length,
|
||||
"num_epochs": num_epochs,
|
||||
"output_dir": output_dir,
|
||||
"val_set_size": val_size,
|
||||
}
|
||||
with open("config.yml", "w", encoding="utf-8") as file:
|
||||
yaml.dump(config_dict, file)
|
||||
print(config_dict)
|
||||
return yaml.dump(config_dict)
|
||||
|
||||
|
||||
with gr.Blocks(title="Axolotl Launcher") as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# Axolotl Launcher
|
||||
Fill out the required fields below to create a training run.
|
||||
"""
|
||||
)
|
||||
with gr.Row():
|
||||
base_model_name = gr.Textbox(
|
||||
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", label="Base model"
|
||||
)
|
||||
|
||||
mode = gr.Radio(
|
||||
choices=["Full finetune", "QLoRA", "LoRA"],
|
||||
label="Training mode",
|
||||
info="FFT = 16 bit, Qlora = 4 bit, Lora = 8 bit",
|
||||
)
|
||||
with gr.Row():
|
||||
dataset_path = gr.Textbox("mhenrichsen/alpaca_2k_test", label="Dataset")
|
||||
dataset_type_name = gr.Dropdown(
|
||||
choices=["alpaca", "sharegpt"], label="Dataset type", value="alpaca"
|
||||
)
|
||||
with gr.Accordion("Hyperparameters", open=False):
|
||||
gr.Markdown("Choose hyperparameters")
|
||||
with gr.Row():
|
||||
learning_rate = gr.Number(0.000001, label="Learning rate")
|
||||
gradient_accumulation_steps_count = gr.Number(
|
||||
1, label="Gradient accumulation steps"
|
||||
)
|
||||
val_set_size_count = gr.Number(0, label="Validation size")
|
||||
|
||||
with gr.Row():
|
||||
micro_batch_size_count = gr.Number(1, label="Micro batch size")
|
||||
sequence_length = gr.Number(1024, label="Sequence length")
|
||||
num_epochs_count = gr.Number(1, label="Epochs")
|
||||
|
||||
output_dir_path = gr.Textbox("./model-out", label="Output directory")
|
||||
|
||||
create_config = gr.Button("Create config")
|
||||
output = gr.TextArea(label="Generated config")
|
||||
create_config.click(
|
||||
config,
|
||||
inputs=[
|
||||
base_model_name,
|
||||
dataset_path,
|
||||
dataset_type_name,
|
||||
learning_rate,
|
||||
gradient_accumulation_steps_count,
|
||||
micro_batch_size_count,
|
||||
sequence_length,
|
||||
num_epochs_count,
|
||||
output_dir_path,
|
||||
val_set_size_count,
|
||||
],
|
||||
outputs=output,
|
||||
)
|
||||
|
||||
demo.launch(debug=True, server_name="0.0.0.0", server_port=7860)
|
||||
Reference in New Issue
Block a user