Compare commits
12 Commits
feat/space
...
streaming-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e08df47584 | ||
|
|
fac2d98c26 | ||
|
|
ea00dd0852 | ||
|
|
b2a4cb4396 | ||
|
|
aaf54dc730 | ||
|
|
9bca7db133 | ||
|
|
91cf4ee72c | ||
|
|
1daecd161e | ||
|
|
4a654b331e | ||
|
|
5698943263 | ||
|
|
411293bdca | ||
|
|
73f1bdaa15 |
2
.github/workflows/base.yml
vendored
2
.github/workflows/base.yml
vendored
@@ -7,7 +7,7 @@ jobs:
|
|||||||
build-base:
|
build-base:
|
||||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: self-hosted
|
runs-on: axolotl-gpu-runner
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
|
|||||||
18
.github/workflows/main.yml
vendored
18
.github/workflows/main.yml
vendored
@@ -9,7 +9,6 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
build-axolotl:
|
build-axolotl:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@@ -35,7 +34,7 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -56,27 +55,16 @@ jobs:
|
|||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
load: true
|
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
file: ./docker/Dockerfile
|
file: ./docker/Dockerfile
|
||||||
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: |
|
tags: |
|
||||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
- name: Unit Tests
|
|
||||||
run: |
|
|
||||||
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
|
||||||
- name: Push to Docker Hub
|
|
||||||
if: github.event_name != 'pull_request'
|
|
||||||
run: |
|
|
||||||
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
|
||||||
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
|
||||||
if [ -n "$latest_tag" ]; then
|
|
||||||
docker push "$latest_tag"
|
|
||||||
fi
|
|
||||||
|
|
||||||
build-axolotl-runpod:
|
build-axolotl-runpod:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
@@ -106,7 +94,7 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ ignore_missing_imports = True
|
|||||||
[mypy-bitsandbytes]
|
[mypy-bitsandbytes]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[mypy-requests]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
|
||||||
[mypy-datasets]
|
[mypy-datasets]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
|||||||
29
README.md
29
README.md
@@ -25,8 +25,8 @@ Features:
|
|||||||
- [Installation](#installation)
|
- [Installation](#installation)
|
||||||
- [Docker](#docker)
|
- [Docker](#docker)
|
||||||
- [Conda/Pip venv](#condapip-venv)
|
- [Conda/Pip venv](#condapip-venv)
|
||||||
- [Cloud GPU](#cloud-gpu) - Runpod, Latitude
|
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
|
||||||
- [LambdaLabs](#lambdalabs)
|
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
|
||||||
- [Windows](#windows)
|
- [Windows](#windows)
|
||||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||||
- [Dataset](#dataset)
|
- [Dataset](#dataset)
|
||||||
@@ -121,6 +121,10 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
# gradio
|
# gradio
|
||||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
--lora_model_dir="./lora-out" --gradio
|
--lora_model_dir="./lora-out" --gradio
|
||||||
|
|
||||||
|
# remote yaml files - the yaml config can be hosted on a public URL
|
||||||
|
# Note: the yaml config must directly link to the **raw** yaml
|
||||||
|
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
@@ -182,9 +186,13 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
|||||||
|
|
||||||
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
||||||
|
|
||||||
|
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||||
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||||
|
|
||||||
#### LambdaLabs
|
#### Bare Metal Cloud GPU
|
||||||
|
|
||||||
|
##### LambdaLabs
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
|
||||||
<summary>Click to Expand</summary>
|
<summary>Click to Expand</summary>
|
||||||
@@ -464,6 +472,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
dataset:
|
dataset:
|
||||||
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
||||||
...
|
...
|
||||||
|
|
||||||
|
# Loading Data From a Public URL
|
||||||
|
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
|
||||||
|
dataset:
|
||||||
|
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
|
||||||
|
ds_type: json # this is the default, see other options below.
|
||||||
```
|
```
|
||||||
|
|
||||||
- loading
|
- loading
|
||||||
@@ -976,6 +990,9 @@ Run
|
|||||||
accelerate launch -m axolotl.cli.train your_config.yml
|
accelerate launch -m axolotl.cli.train your_config.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
|
||||||
|
|
||||||
#### Preprocess dataset
|
#### Preprocess dataset
|
||||||
|
|
||||||
You can optionally pre-tokenize dataset with the following before finetuning.
|
You can optionally pre-tokenize dataset with the following before finetuning.
|
||||||
@@ -1200,6 +1217,12 @@ pre-commit install
|
|||||||
pytest tests/
|
pytest tests/
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
|
||||||
|
|
||||||
|
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
|
||||||
|
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
|
||||||
|
</a>
|
||||||
|
|
||||||
## Sponsors 🤝❤
|
## Sponsors 🤝❤
|
||||||
|
|
||||||
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
|
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
|
||||||
|
|||||||
65
examples/tiny-llama/lora-mps.yml
Normal file
65
examples/tiny-llama/lora-mps.yml
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_llama_derived_model: true
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./lora-out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16: false
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: false
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 0
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
@@ -10,9 +10,9 @@ strict: false
|
|||||||
|
|
||||||
max_steps: 200
|
max_steps: 200
|
||||||
pretraining_dataset:
|
pretraining_dataset:
|
||||||
path: c4
|
- path: c4
|
||||||
name: en
|
name: en
|
||||||
type: pretrain
|
type: pretrain
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./model-out
|
output_dir: ./model-out
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
pre-commit
|
pre-commit
|
||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
|
types-requests
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ deepspeed>=0.13.1
|
|||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
|
requests
|
||||||
datasets>=2.15.0
|
datasets>=2.15.0
|
||||||
flash-attn==2.3.3
|
flash-attn==2.3.3
|
||||||
sentencepiece
|
sentencepiece
|
||||||
|
|||||||
24
setup.py
24
setup.py
@@ -1,5 +1,7 @@
|
|||||||
"""setup.py for axolotl"""
|
"""setup.py for axolotl"""
|
||||||
|
|
||||||
|
import platform
|
||||||
|
import re
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
@@ -26,11 +28,25 @@ def parse_requirements():
|
|||||||
_install_requires.append(line)
|
_install_requires.append(line)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
torch_version = version("torch")
|
if "Darwin" in platform.system():
|
||||||
_install_requires.append(f"torch=={torch_version}")
|
|
||||||
if torch_version.startswith("2.1."):
|
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||||
_install_requires.append("xformers>=0.0.23")
|
else:
|
||||||
|
torch_version = version("torch")
|
||||||
|
_install_requires.append(f"torch=={torch_version}")
|
||||||
|
|
||||||
|
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||||
|
if version_match:
|
||||||
|
major, minor, patch = version_match.groups()
|
||||||
|
major, minor = int(major), int(minor)
|
||||||
|
patch = (
|
||||||
|
int(patch) if patch is not None else 0
|
||||||
|
) # Default patch to 0 if not present
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
|
if (major, minor) >= (2, 1):
|
||||||
|
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||||
|
_install_requires.append("xformers>=0.0.23")
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,20 @@
|
|||||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@@ -59,6 +63,52 @@ def print_axolotl_text_art(suffix=None):
|
|||||||
print(ascii_art)
|
print(ascii_art)
|
||||||
|
|
||||||
|
|
||||||
|
def check_remote_config(config: Union[str, Path]):
|
||||||
|
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
|
||||||
|
if not (isinstance(config, str) and config.startswith("https://")):
|
||||||
|
return config # Return the original value if it's not a valid URL
|
||||||
|
|
||||||
|
filename = os.path.basename(urlparse(config).path)
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(config, timeout=30)
|
||||||
|
response.raise_for_status() # Check for HTTP errors
|
||||||
|
|
||||||
|
content = response.content
|
||||||
|
try:
|
||||||
|
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
|
||||||
|
json.loads(content)
|
||||||
|
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
|
||||||
|
LOG.warning(
|
||||||
|
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# If it's not valid JSON, verify it's valid YAML
|
||||||
|
try:
|
||||||
|
yaml.safe_load(content)
|
||||||
|
except yaml.YAMLError as err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to parse the content at {config} as YAML: {err}"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
|
||||||
|
output_path = Path(temp_dir) / filename
|
||||||
|
with open(output_path, "wb") as file:
|
||||||
|
file.write(content)
|
||||||
|
LOG.info(
|
||||||
|
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
|
||||||
|
)
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
except requests.RequestException as err:
|
||||||
|
# This catches all requests-related exceptions including HTTPError
|
||||||
|
raise RuntimeError(f"Failed to download {config}: {err}") from err
|
||||||
|
except Exception as err:
|
||||||
|
# Catch-all for any other exceptions
|
||||||
|
raise err
|
||||||
|
|
||||||
|
|
||||||
def get_multi_line_input() -> Optional[str]:
|
def get_multi_line_input() -> Optional[str]:
|
||||||
print("Give me an instruction (Ctrl + D to submit): ")
|
print("Give me an instruction (Ctrl + D to submit): ")
|
||||||
instruction = ""
|
instruction = ""
|
||||||
@@ -270,9 +320,10 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
|
|||||||
return not any(el in list2 for el in list1)
|
return not any(el in list2 for el in list1)
|
||||||
|
|
||||||
|
|
||||||
def load_cfg(config: Path = Path("examples/"), **kwargs):
|
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||||
|
config = check_remote_config(config)
|
||||||
if Path(config).is_dir():
|
if Path(config).is_dir():
|
||||||
config = choose_config(config)
|
config = choose_config(Path(config))
|
||||||
|
|
||||||
# load the config from the yaml file
|
# load the config from the yaml file
|
||||||
with open(config, encoding="utf-8") as file:
|
with open(config, encoding="utf-8") as file:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ CLI to run training on a model
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
@@ -23,7 +24,7 @@ from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
|||||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ CLI to shard a trained model into 10GiB chunks
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
@@ -25,7 +26,7 @@ def shard(
|
|||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ CLI to run training on a model
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
@@ -25,7 +25,7 @@ from axolotl.train import train
|
|||||||
LOG = logging.getLogger("axolotl.cli.train")
|
LOG = logging.getLogger("axolotl.cli.train")
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
parser = HfArgumentParser((TrainerCliArgs))
|
parser = HfArgumentParser((TrainerCliArgs))
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from transformers import (
|
|||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
@@ -994,7 +995,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
if use_batch_sampler_collator:
|
if use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
elif (
|
elif (
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
"""
|
|
||||||
Patches to support multipack for falcon
|
|
||||||
"""
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
|
||||||
|
|
||||||
|
|
||||||
def replace_falcon_attn_with_multipack_flash_attn():
|
|
||||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
@@ -2,9 +2,6 @@
|
|||||||
Patches to support multipack for mixtral
|
Patches to support multipack for mixtral
|
||||||
"""
|
"""
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
|
||||||
|
|
||||||
|
|
||||||
def patch_mixtral_moe_forward_zero3() -> None:
|
def patch_mixtral_moe_forward_zero3() -> None:
|
||||||
@@ -51,11 +48,3 @@ def patch_mixtral_moe_forward_zero3() -> None:
|
|||||||
|
|
||||||
MixtralBLockSparseTop2MLP.forward = mlp_forward
|
MixtralBLockSparseTop2MLP.forward = mlp_forward
|
||||||
MixtralSparseMoeBlock.forward = moe_forward
|
MixtralSparseMoeBlock.forward = moe_forward
|
||||||
|
|
||||||
|
|
||||||
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
|
|
||||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
if for_zero3:
|
|
||||||
patch_mixtral_moe_forward_zero3()
|
|
||||||
|
|||||||
30
src/axolotl/monkeypatch/multipack.py
Normal file
30
src/axolotl/monkeypatch/multipack.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""multipack patching for v2 of sample packing"""
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||||
|
from axolotl.monkeypatch.utils import get_unpad_data
|
||||||
|
|
||||||
|
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
|
||||||
|
|
||||||
|
|
||||||
|
def patch_for_multipack(model_type):
|
||||||
|
if model_type == "mixtral":
|
||||||
|
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
patch_mixtral_moe_forward_zero3()
|
||||||
|
elif model_type == "qwen2":
|
||||||
|
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "falcon":
|
||||||
|
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "phi":
|
||||||
|
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
"""
|
|
||||||
Patches to support multipack for phi2
|
|
||||||
"""
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
|
||||||
|
|
||||||
|
|
||||||
def replace_phi_attn_with_multipack_flash_attn():
|
|
||||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
"""
|
|
||||||
Patches to support multipack for qwen2
|
|
||||||
"""
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
|
||||||
|
|
||||||
|
|
||||||
def replace_qwen2_attn_with_multipack_flash_attn():
|
|
||||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
@@ -186,8 +186,8 @@ def mask_2d_to_4d(
|
|||||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||||
binary_mask = torch.where(
|
binary_mask = torch.where(
|
||||||
mask != 0,
|
mask != 0,
|
||||||
torch.tensor(1).to(dtype),
|
torch.tensor(1, device=mask.device).to(dtype),
|
||||||
torch.tensor(0).to(dtype),
|
torch.tensor(0, device=mask.device).to(dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a block-diagonal mask.
|
# Create a block-diagonal mask.
|
||||||
|
|||||||
0
src/axolotl/plugins/oaaic/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/data/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/data/__init__.py
Normal file
28
src/axolotl/plugins/oaaic/data/streaming_sql.py
Normal file
28
src/axolotl/plugins/oaaic/data/streaming_sql.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import os
|
||||||
|
from typing import Callable, Generator, Tuple
|
||||||
|
|
||||||
|
import psycopg
|
||||||
|
import psycopg.conninfo
|
||||||
|
|
||||||
|
|
||||||
|
def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
|
||||||
|
pgsql_conn = os.environ.get("PGSQL_CONN", None)
|
||||||
|
if not pgsql_conn:
|
||||||
|
raise ValueError("missing PGSQL_CONN environment variable")
|
||||||
|
conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)
|
||||||
|
|
||||||
|
def data_generator() -> Generator[Tuple, None, None]:
|
||||||
|
with psycopg.connect(**conn_dict) as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
page_size = 10
|
||||||
|
last_id = None
|
||||||
|
while True:
|
||||||
|
if last_id:
|
||||||
|
where_clause = f" WHERE {id_field} > {last_id}"
|
||||||
|
cur.execute(
|
||||||
|
f"SELECT * FROM {pgsql_table}{where_clause} ORDER BY {id_field} ASC LIMIT {page_size}"
|
||||||
|
)
|
||||||
|
for row in cur.fetchall():
|
||||||
|
yield row[id_field], dict(row)
|
||||||
|
|
||||||
|
return data_generator
|
||||||
@@ -208,7 +208,10 @@ def train(
|
|||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
try:
|
||||||
|
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
elif cfg.hub_model_id:
|
elif cfg.hub_model_id:
|
||||||
# defensively push to the hub to ensure the model card is updated
|
# defensively push to the hub to ensure the model card is updated
|
||||||
trainer.push_to_hub()
|
trainer.push_to_hub()
|
||||||
|
|||||||
@@ -47,6 +47,12 @@ def gpu_memory_usage_all(device=0):
|
|||||||
return usage, reserved - usage, max(0, smi - reserved)
|
return usage, reserved - usage, max(0, smi - reserved)
|
||||||
|
|
||||||
|
|
||||||
|
def mps_memory_usage_all():
|
||||||
|
usage = torch.mps.current_allocated_memory() / 1024.0**3
|
||||||
|
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
|
||||||
|
return usage, reserved - usage, 0
|
||||||
|
|
||||||
|
|
||||||
@check_cuda_device(0.0)
|
@check_cuda_device(0.0)
|
||||||
def gpu_memory_usage_smi(device=0):
|
def gpu_memory_usage_smi(device=0):
|
||||||
if isinstance(device, torch.device):
|
if isinstance(device, torch.device):
|
||||||
@@ -63,7 +69,10 @@ def gpu_memory_usage_smi(device=0):
|
|||||||
|
|
||||||
|
|
||||||
def log_gpu_memory_usage(log, msg, device):
|
def log_gpu_memory_usage(log, msg, device):
|
||||||
usage, cache, misc = gpu_memory_usage_all(device)
|
if torch.backends.mps.is_available():
|
||||||
|
usage, cache, misc = mps_memory_usage_all()
|
||||||
|
else:
|
||||||
|
usage, cache, misc = gpu_memory_usage_all(device)
|
||||||
extras = []
|
extras = []
|
||||||
if cache > 0:
|
if cache > 0:
|
||||||
extras.append(f"+{cache:.03f}GB cache")
|
extras.append(f"+{cache:.03f}GB cache")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Module containing data utilities"""
|
"""Module containing data utilities"""
|
||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -11,10 +12,12 @@ import yaml
|
|||||||
from datasets import (
|
from datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
DatasetDict,
|
DatasetDict,
|
||||||
|
IterableDataset,
|
||||||
concatenate_datasets,
|
concatenate_datasets,
|
||||||
load_dataset,
|
load_dataset,
|
||||||
load_from_disk,
|
load_from_disk,
|
||||||
)
|
)
|
||||||
|
from datasets.iterable_dataset import ExamplesIterable
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from huggingface_hub.utils import HFValidationError
|
from huggingface_hub.utils import HFValidationError
|
||||||
from torch.utils.data import RandomSampler
|
from torch.utils.data import RandomSampler
|
||||||
@@ -64,6 +67,25 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
|||||||
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
||||||
|
|
||||||
|
|
||||||
|
def get_streaming_dataset(ds_cfg):
|
||||||
|
path = ds_cfg["path"]
|
||||||
|
func = None
|
||||||
|
try:
|
||||||
|
load_fn = path.split(".")[-1]
|
||||||
|
module_name = ".".join(load_fn.split(".")[:-1])
|
||||||
|
mod = importlib.import_module(f".{module_name}", "axolotl")
|
||||||
|
func = getattr(mod, load_fn)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if func:
|
||||||
|
data_producer = func(**ds_cfg)
|
||||||
|
return IterableDataset(ExamplesIterable(data_producer, {}))
|
||||||
|
else:
|
||||||
|
split = ds_cfg["split"] or "train"
|
||||||
|
return load_dataset(path, streaming=True, split=split, name=ds_cfg["name"])
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer):
|
||||||
prompters = []
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
@@ -80,14 +102,6 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
path = cfg.pretraining_dataset
|
|
||||||
name = None
|
|
||||||
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
|
||||||
cfg.pretraining_dataset[0], dict
|
|
||||||
):
|
|
||||||
path = cfg.pretraining_dataset[0]["path"]
|
|
||||||
name = cfg.pretraining_dataset[0]["name"]
|
|
||||||
|
|
||||||
ds_wrapper_partial = functools.partial(
|
ds_wrapper_partial = functools.partial(
|
||||||
get_dataset_wrapper,
|
get_dataset_wrapper,
|
||||||
cfg.pretraining_dataset[0],
|
cfg.pretraining_dataset[0],
|
||||||
@@ -97,7 +111,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = wrap_pretraining_dataset(
|
train_dataset = wrap_pretraining_dataset(
|
||||||
load_dataset(path, streaming=True, split="train", name=name),
|
get_streaming_dataset(cfg.pretraining_dataset[0]),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg,
|
cfg,
|
||||||
ds_wrapper_partial,
|
ds_wrapper_partial,
|
||||||
@@ -336,6 +350,16 @@ def load_tokenized_prepared_datasets(
|
|||||||
split=None,
|
split=None,
|
||||||
storage_options=storage_options,
|
storage_options=storage_options,
|
||||||
)
|
)
|
||||||
|
elif config_dataset.path.startswith("https://"):
|
||||||
|
ds_type = get_ds_type(config_dataset)
|
||||||
|
ds = load_dataset(
|
||||||
|
ds_type,
|
||||||
|
name=config_dataset.name,
|
||||||
|
data_files=config_dataset.path,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
storage_options=storage_options,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if isinstance(config_dataset.data_files, str):
|
if isinstance(config_dataset.data_files, str):
|
||||||
fp = hf_hub_download(
|
fp = hf_hub_download(
|
||||||
|
|||||||
@@ -29,6 +29,10 @@ from transformers import ( # noqa: F401
|
|||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
|
from axolotl.monkeypatch.multipack import (
|
||||||
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
|
patch_for_multipack,
|
||||||
|
)
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
@@ -299,8 +303,15 @@ def load_model(
|
|||||||
shifted-sparse attention does not currently support sample packing."
|
shifted-sparse attention does not currently support sample packing."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Modify all llama derived models in one block
|
if (
|
||||||
if cfg.is_llama_derived_model:
|
cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
|
and cfg.flash_attention
|
||||||
|
and cfg.sample_packing
|
||||||
|
):
|
||||||
|
patch_for_multipack(cfg.model_config_type)
|
||||||
|
elif cfg.is_llama_derived_model:
|
||||||
|
# Modify all llama derived models in one block
|
||||||
|
|
||||||
if cfg.flash_attention:
|
if cfg.flash_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
replace_llama_attn_with_flash_attn,
|
replace_llama_attn_with_flash_attn,
|
||||||
@@ -354,43 +365,6 @@ def load_model(
|
|||||||
LOG.info("patching mistral with flash attention")
|
LOG.info("patching mistral with flash attention")
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||||
|
|
||||||
if (
|
|
||||||
cfg.model_config_type == "mixtral"
|
|
||||||
and cfg.flash_attention
|
|
||||||
and cfg.sample_packing
|
|
||||||
):
|
|
||||||
from axolotl.monkeypatch.mixtral import (
|
|
||||||
replace_mixtral_attn_with_multipack_flash_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching mixtral with flash attention")
|
|
||||||
mixtral_patch_kwargs = {}
|
|
||||||
if is_deepspeed_zero3_enabled():
|
|
||||||
mixtral_patch_kwargs["for_zero3"] = True
|
|
||||||
replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
|
|
||||||
|
|
||||||
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
|
|
||||||
from axolotl.monkeypatch.falcon import (
|
|
||||||
replace_falcon_attn_with_multipack_flash_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching falcon with flash attention")
|
|
||||||
replace_falcon_attn_with_multipack_flash_attn()
|
|
||||||
|
|
||||||
if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
|
|
||||||
from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
|
|
||||||
|
|
||||||
LOG.info("patching phi with flash attention")
|
|
||||||
replace_phi_attn_with_multipack_flash_attn()
|
|
||||||
|
|
||||||
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
|
|
||||||
from axolotl.monkeypatch.qwen2 import (
|
|
||||||
replace_qwen2_attn_with_multipack_flash_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching qwen2 with flash attention")
|
|
||||||
replace_qwen2_attn_with_multipack_flash_attn()
|
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
||||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||||
|
|
||||||
@@ -400,7 +374,7 @@ def load_model(
|
|||||||
model_kwargs: Dict[str, Any] = {}
|
model_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
if cfg.model_kwargs:
|
if cfg.model_kwargs:
|
||||||
for key, val in model_kwargs.items():
|
for key, val in cfg.model_kwargs.items():
|
||||||
model_kwargs[key] = val
|
model_kwargs[key] = val
|
||||||
|
|
||||||
max_memory = cfg.max_memory
|
max_memory = cfg.max_memory
|
||||||
@@ -435,6 +409,10 @@ def load_model(
|
|||||||
|
|
||||||
model_kwargs["device_map"] = device_map
|
model_kwargs["device_map"] = device_map
|
||||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||||
|
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
model_kwargs["device_map"] = "mps:0"
|
||||||
|
|
||||||
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
|
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
|
||||||
# if cfg.rl:
|
# if cfg.rl:
|
||||||
# if torch.cuda.device_count() > 1:
|
# if torch.cuda.device_count() > 1:
|
||||||
@@ -501,7 +479,7 @@ def load_model(
|
|||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
@@ -677,7 +655,7 @@ def load_model(
|
|||||||
):
|
):
|
||||||
model.config.eos_token_id = tokenizer.eos_token_id
|
model.config.eos_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
if hasattr(model, "device") and model.device.type == "cuda":
|
if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
|
||||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||||
|
|
||||||
# make sure these are fp32 per Ramesh et al. (2021)
|
# make sure these are fp32 per Ramesh et al. (2021)
|
||||||
|
|||||||
98
ui/main.py
98
ui/main.py
@@ -1,98 +0,0 @@
|
|||||||
"""
|
|
||||||
This module is used to launch Axolotl with user defined configurations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gradio as gr
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
|
|
||||||
def config(
|
|
||||||
base_model,
|
|
||||||
dataset,
|
|
||||||
dataset_type,
|
|
||||||
learn_rate,
|
|
||||||
gradient_accumulation_steps,
|
|
||||||
micro_batch_size,
|
|
||||||
seq_length,
|
|
||||||
num_epochs,
|
|
||||||
output_dir,
|
|
||||||
val_size,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This function generates a configuration dictionary and saves it as a yaml file.
|
|
||||||
"""
|
|
||||||
config_dict = {
|
|
||||||
"base_model": base_model,
|
|
||||||
"datasets": [{"path": dataset, "type": dataset_type}],
|
|
||||||
"learning_rate": learn_rate,
|
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
|
||||||
"micro_batch_size": micro_batch_size,
|
|
||||||
"sequence_len": seq_length,
|
|
||||||
"num_epochs": num_epochs,
|
|
||||||
"output_dir": output_dir,
|
|
||||||
"val_set_size": val_size,
|
|
||||||
}
|
|
||||||
with open("config.yml", "w", encoding="utf-8") as file:
|
|
||||||
yaml.dump(config_dict, file)
|
|
||||||
print(config_dict)
|
|
||||||
return yaml.dump(config_dict)
|
|
||||||
|
|
||||||
|
|
||||||
with gr.Blocks(title="Axolotl Launcher") as demo:
|
|
||||||
gr.Markdown(
|
|
||||||
"""
|
|
||||||
# Axolotl Launcher
|
|
||||||
Fill out the required fields below to create a training run.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
base_model_name = gr.Textbox(
|
|
||||||
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", label="Base model"
|
|
||||||
)
|
|
||||||
|
|
||||||
mode = gr.Radio(
|
|
||||||
choices=["Full finetune", "QLoRA", "LoRA"],
|
|
||||||
label="Training mode",
|
|
||||||
info="FFT = 16 bit, Qlora = 4 bit, Lora = 8 bit",
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
dataset_path = gr.Textbox("mhenrichsen/alpaca_2k_test", label="Dataset")
|
|
||||||
dataset_type_name = gr.Dropdown(
|
|
||||||
choices=["alpaca", "sharegpt"], label="Dataset type", value="alpaca"
|
|
||||||
)
|
|
||||||
with gr.Accordion("Hyperparameters", open=False):
|
|
||||||
gr.Markdown("Choose hyperparameters")
|
|
||||||
with gr.Row():
|
|
||||||
learning_rate = gr.Number(0.000001, label="Learning rate")
|
|
||||||
gradient_accumulation_steps_count = gr.Number(
|
|
||||||
1, label="Gradient accumulation steps"
|
|
||||||
)
|
|
||||||
val_set_size_count = gr.Number(0, label="Validation size")
|
|
||||||
|
|
||||||
with gr.Row():
|
|
||||||
micro_batch_size_count = gr.Number(1, label="Micro batch size")
|
|
||||||
sequence_length = gr.Number(1024, label="Sequence length")
|
|
||||||
num_epochs_count = gr.Number(1, label="Epochs")
|
|
||||||
|
|
||||||
output_dir_path = gr.Textbox("./model-out", label="Output directory")
|
|
||||||
|
|
||||||
create_config = gr.Button("Create config")
|
|
||||||
output = gr.TextArea(label="Generated config")
|
|
||||||
create_config.click(
|
|
||||||
config,
|
|
||||||
inputs=[
|
|
||||||
base_model_name,
|
|
||||||
dataset_path,
|
|
||||||
dataset_type_name,
|
|
||||||
learning_rate,
|
|
||||||
gradient_accumulation_steps_count,
|
|
||||||
micro_batch_size_count,
|
|
||||||
sequence_length,
|
|
||||||
num_epochs_count,
|
|
||||||
output_dir_path,
|
|
||||||
val_set_size_count,
|
|
||||||
],
|
|
||||||
outputs=output,
|
|
||||||
)
|
|
||||||
|
|
||||||
demo.launch(debug=True, server_name="0.0.0.0", server_port=7860)
|
|
||||||
Reference in New Issue
Block a user