Compare commits

..

9 Commits

Author SHA1 Message Date
Wing Lian
587dbbfc02 fix the patch to work properly and work with FSDP
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-17 20:33:56 -04:00
Wing Lian
6c306d9186 enable eval dataloader using multipack again 2023-08-17 18:17:31 -04:00
Wing Lian
7565fb9d63 update forwards so we only calculate cu_seqlens once 2023-08-17 18:17:31 -04:00
Wing Lian
a6b737d5ff copy LlamaModel.forward and LlamaDecoderLayer.forward into monkeypatch 2023-08-17 18:17:30 -04:00
Wing Lian
cf95b57c0a fix patch to check position ids and don't use multipack for evals 2023-08-17 18:17:30 -04:00
Aman Karmani
13f7efaf74 speed up flash-attn inference 2023-08-17 18:17:30 -04:00
Aman Karmani
d773384f74 update flash-attn patch for 70B/GQA and inference using helper from flash-attn tests 2023-08-17 18:17:30 -04:00
Aman Karmani
985dcbc051 sync xformers patch to follow shared format and be diffable 2023-08-17 18:17:30 -04:00
Aman Karmani
5d0b27e5a1 split sdp attn into its own patch 2023-08-17 18:17:30 -04:00
24 changed files with 187 additions and 432 deletions

View File

@@ -13,17 +13,17 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 118
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.1
@@ -49,11 +49,10 @@ jobs:
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-runpod:
needs: build-axolotl

View File

@@ -24,7 +24,7 @@ jobs:
- name: Install dependencies
run: |
pip install -e .[peft]
pip install -e .
pip install -r requirements-tests.txt
- name: Run tests

View File

@@ -16,7 +16,6 @@ Axolotl is a tool designed to streamline the fine-tuning of various AI models, o
- [LambdaLabs Installation](#lambdalabs)
- [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config)
- [Train](#train)
- [Inference](#inference)
@@ -69,9 +68,8 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install -e .[flash-attn]
pip3 install -e .
pip3 install -U git+https://github.com/huggingface/peft.git
# finetune lora
@@ -100,7 +98,7 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
```
- Conda/Pip venv
1. Install python >=**3.9**
1. Install python **3.9**
2. Install pytorch stable https://pytorch.org/get-started/locally/
@@ -153,7 +151,9 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
pip3 install -e . # change depend on needs
pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy
pip3 install -U requests
pip3 install -U --ignore-installed psutil
pip3 install -U scipy
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
```
@@ -257,10 +257,6 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `metharme`: instruction, adds additional eos tokens
```json
{"prompt": "...", "generation": "..."}
```
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
```json
{"conversations": [{"role": "...", "value": "..."}]}
@@ -278,29 +274,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
#### How to add custom prompts
Using yaml. Example:
```yaml
datasets:
- path: repo
type:
system_prompt: ""
no_input_format: |-
User: {instruction}<|end_of_turn|>
Assistant:
format: |-
User: {instruction}
{input}<|end_of_turn|>
Assistant:
```
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
Using file:
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
Optionally, download some datasets, see [data/README.md](data/README.md)
#### How to use your custom pretokenized dataset
- Do not pass a `type:`
- Dataset must contain `input_ids`, `attention_mask`, `labels` in columns
### Config
@@ -330,9 +308,9 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
# local
datasets:
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
- path: json
data_files: data.jsonl # or json
type: alpaca # format from earlier
```
- loading
@@ -413,29 +391,10 @@ datasets:
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file
data_files: # path to source data files
shards: # number of shards to split data into
name: # name of dataset configuration to load
# custom user prompt
- path: repo
type:
# the below are defaults. only set what's needed.
system_prompt: ""
field_system: system
field_instruction: instruction
field_output: input
# customizable to be single line or multi-line
system_format: "{system}"
# 'format' can include {input}
format: |-
User: {instruction} {input}
Assistant:
# 'no_input_format' cannot include {input}
no_input_format: "{instruction} "
# axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
@@ -513,7 +472,6 @@ warmup_steps: 100
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
save_strategy: # set to `no` to skip checkpoint saves
save_steps: # leave empty to save at each epoch
eval_steps:
save_total_limit: # checkpoints saved at a time
@@ -708,9 +666,7 @@ Please reduce any below
- `gradient_accumulation_steps`
- `sequence_len`
> `failed (exitcode: -9)`
Usually means your system has run out of system memory.
> `failed (exitcode: -9)` usually means your system has run out of system memory.
Similarly, you should consider reducing the same settings as when you run out of VRAM.
Additionally, look into upgrading your system RAM which should be simpler than GPU upgrades.

24
data/README.md Normal file
View File

@@ -0,0 +1,24 @@
## Download some datasets
```shell
curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o data/raw/alpaca_data_gpt4.json
curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o data/raw/vicuna_cleaned.json
curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o data/raw/gpt4-instruct-similarity-0.6-dataset.json
curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o data/raw/roleplay-similarity_0.6-instruct-dataset.json
```
## Convert the JSON data files to JSONL.
```shell
python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
```
---
Using JSONL makes it easier to subset the data if you want a smaller training set, i.e get 2000 random examples.
```shell
shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
```

1
data/raw/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
**

View File

@@ -16,9 +16,9 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
pip install -e .[$AXOLOTL_EXTRAS]; \
else \
pip install -e .[flash-attn]; \
pip install -e .; \
fi
# fix so that git fetch/pull from remote works

View File

@@ -31,6 +31,26 @@ WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
FROM base-builder AS flash-attn-builder
WORKDIR /workspace
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
cd flash-attention && \
git checkout v2.0.4 && \
python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \
cd ../xentropy && \
python3 setup.py bdist_wheel && \
cd ../rotary && \
python3 setup.py bdist_wheel && \
cd ../layer_norm && \
python3 setup.py bdist_wheel
FROM base-builder AS deepspeed-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
@@ -70,8 +90,13 @@ RUN mkdir -p /workspace/wheels/bitsandbytes
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy_cuda_lib-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary_emb-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels
RUN pip3 install wheels/deepspeed-*.whl
RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy_cuda_lib-*.whl wheels/rotary_emb-*.whl wheels/dropout_layer_norm-*.whl
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
RUN git lfs install --skip-repo
RUN pip3 install awscli && \

View File

@@ -6,14 +6,13 @@ addict
fire
PyYAML==6.0
datasets
flash-attn==2.0.8
accelerate>=0.19.0
sentencepiece
wandb
einops
xformers
optimum
hf_transfer
colorama
numba
numpy==1.24.4
# qlora things

View File

@@ -0,0 +1,52 @@
"""Module to convert json file to jsonl"""
import os
import sys
from pathlib import Path
from typing import Optional, Union
import fire
from axolotl.convert import (
FileReader,
FileWriter,
JsonlSerializer,
JsonParser,
JsonToJsonlConverter,
StdoutWriter,
)
from axolotl.logging_config import configure_logging
configure_logging()
# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
def main(
file: Path,
output: Optional[Path] = None,
to_stdout: Optional[bool] = False,
):
"""
Convert a json file to jsonl
"""
file_reader = FileReader()
writer: Union[StdoutWriter, FileWriter]
if to_stdout or output is None:
writer = StdoutWriter()
else:
writer = FileWriter(output)
json_parser = JsonParser()
jsonl_serializer = JsonlSerializer()
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
converter.convert(file, output)
if __name__ == "__main__":
fire.Fire(main)

View File

@@ -155,23 +155,6 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
return not any(el in list2 for el in list1)
def merge_lora(model, tokenizer, cfg):
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model_dtype = torch.bfloat16 if cfg.bf16 or cfg.bfloat16 else torch.float16
model.to(dtype=model_dtype)
if cfg.hub_model_id:
model.push_to_hub("hub_model_id")
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=cfg.save_safetensors is True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def train(
config: Path = Path("configs/"),
prepare_ds_only: bool = False,
@@ -231,7 +214,17 @@ def train(
safe_serialization = cfg.save_safetensors is True
if "merge_lora" in kwargs and cfg.adapter is not None:
merge_lora(model, tokenizer, cfg)
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if cfg.inference:
@@ -317,9 +310,6 @@ def train(
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if cfg.adapter is not None:
merge_lora(model, tokenizer, cfg)
if __name__ == "__main__":
fire.Fire(train)

View File

@@ -7,7 +7,6 @@ with open("./requirements.txt", encoding="utf-8") as requirements_file:
# don't include peft yet until we check the int4
# need to manually install peft for now...
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
reqs = [r for r in reqs if "flash-attn" not in r]
reqs = [r for r in reqs if r and r[0] != "#"]
for r in reqs:
install_requires.append(r)
@@ -26,14 +25,9 @@ setup(
"gptq_triton": [
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
],
"flash-attn": [
"flash-attn==2.0.8",
],
"extras": [
"flash-attn",
"deepspeed",
],
"peft": [
"peft @ git+https://github.com/huggingface/peft.git",
],
},
)

View File

@@ -1,42 +1,16 @@
"""
Common logging module for axolotl
"""
"""Logging configuration settings"""
import os
import sys
from logging import Formatter
from logging.config import dictConfig
from typing import Any, Dict
from colorama import Fore, Style, init
class ColorfulFormatter(Formatter):
"""
Formatter to add coloring to log messages by log type
"""
COLORS = {
"WARNING": Fore.YELLOW,
"ERROR": Fore.RED,
"CRITICAL": Fore.RED + Style.BRIGHT,
}
def format(self, record):
log_message = super().format(record)
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"formatters": {
"simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
},
"colorful": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
},
},
"filters": {},
"handlers": {
@@ -46,25 +20,14 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"filters": [],
"stream": sys.stdout,
},
"color_console": {
"class": "logging.StreamHandler",
"formatter": "colorful",
"filters": [],
"stream": sys.stdout,
},
},
"root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
"loggers": {
"axolotl": {
"handlers": ["color_console"],
"level": "DEBUG",
"propagate": False,
},
"axolotl": {"handlers": ["console"], "level": "DEBUG", "propagate": False},
},
}
def configure_logging():
"""Configure with default logging"""
init() # Initialize colorama
dictConfig(DEFAULT_LOGGING_CONFIG)

View File

@@ -158,7 +158,7 @@ def flashattn_forward(
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
is_causal = past_key_value is not None
if cu_seqlens is not None and max_seqlen is not None:
# special handling using sample packing
@@ -169,7 +169,7 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=is_causal
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:

View File

@@ -2,10 +2,8 @@
import importlib
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
def load(strategy, tokenizer, cfg, ds_cfg):
def load(strategy, tokenizer, cfg):
try:
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
@@ -13,9 +11,6 @@ def load(strategy, tokenizer, cfg, ds_cfg):
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
func = getattr(mod, load_fn)
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
return func(tokenizer, cfg, **load_kwargs)
return func(tokenizer, cfg)
except Exception: # pylint: disable=broad-exception-caught
return None

View File

@@ -57,8 +57,6 @@ class SystemDataPrompter(AlpacaPrompter):
Alpaca Style Prompter that uses system prompts from the dataset
"""
system_format: str = "### System:\n{system}\n\n"
def build_prompt_w_system(
self,
system: str,

View File

@@ -1,76 +0,0 @@
"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
import logging
from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
# pylint: disable=duplicate-code
class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for the Metharme models
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (prompt["prompt"], "", prompt["generation"])
def _tokenize(
self,
prompt: str,
add_eos_token: bool = True,
strip_bos_token: bool = False,
num_eos_tokens: int = 3,
):
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
# If there's already an EOS token there, subtract from the number added
if result["input_ids"][-1] == self.tokenizer.eos_token_id:
num_eos_tokens -= 1
if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
for _ in range(num_eos_tokens):
if len(result["input_ids"]) < self.sequence_len:
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
result["labels"] = result["input_ids"].copy()
return result
class MetharmePrompter(AlpacaPrompter):
"""
Prompter for the Metharme models.
"""
system_prompt = ""
system_no_input_prompt = ""
system_format = ""
turn_format = "{instruction}"
turn_no_input_format = "{instruction}"
def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
pass
def load(tokenizer, cfg):
return MetharmePromptTokenizingStrategy(
MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
)

View File

@@ -1,98 +0,0 @@
"""
User Defined prompts with configuration from the YML config
"""
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple
from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy,
SystemDataPrompter,
)
@dataclass
class UserDefinedDatasetConfig:
"""
dataclass configuration representing a userdefined dataset type
"""
system_prompt: str = ""
field_system: str = "system"
field_instruction: str = "instruction"
field_input: str = "input"
field_output: str = "output"
format: str = "{instruction} {input} "
no_input_format: str = "{instruction} "
system_format: str = "{system}"
def __getitem__(self, item):
return getattr(self, item)
class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
"""
Prompt Tokenization Strategy for user defined prompts
"""
def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
if not ds_cfg:
raise ValueError("Missing dataset prompt configuration")
system_prompt = ""
if ds_cfg.system_prompt:
system_prompt = ds_cfg.system_prompt
def parse_instruction_fields(
field_instruction,
field_input,
field_output,
field_system,
system_prompt,
prompt,
) -> Tuple[str, str, str, str]:
return (
prompt[field_instruction],
prompt[field_input] if field_input in prompt else "",
prompt[field_output] if field_output in prompt else "",
prompt[field_system] if field_system in prompt else system_prompt,
)
turn_format = ds_cfg.format
turn_no_input_format = ds_cfg.no_input_format
system_format = ds_cfg.system_format
class UserDefinedPrompter(SystemDataPrompter):
"""
Prompter for user defined prompts
"""
def match_prompt_style(self):
self.turn_format = turn_format
self.turn_no_input_format = turn_no_input_format
self.system_format = system_format
prompter = UserDefinedPrompter()
strat = UserDefinedPromptTokenizationStrategy(
prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
setattr(
strat,
"parse_instruction_fields",
partial(
parse_instruction_fields,
ds_cfg.field_instruction,
ds_cfg.field_input,
ds_cfg.field_output,
ds_cfg.field_system,
system_prompt,
),
)
return strat

View File

@@ -85,11 +85,7 @@ class PromptTokenizingStrategy(abc.ABC):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if (
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]

View File

@@ -26,7 +26,7 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
system_format: str = "{system}"
system_format: str
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
@@ -63,17 +63,13 @@ class AlpacaPrompter:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = (
self.system_format.format(system=self.system_prompt)
if self.system_prompt
else ""
) + self.turn_format.format(instruction=instruction, input=input)
res = self.system_prompt + self.turn_format.format(
instruction=instruction, input=input
)
else:
res = (
self.system_format.format(system=self.system_no_input_prompt)
if self.system_prompt
else ""
) + self.turn_no_input_format.format(instruction=instruction)
res = self.system_no_input_prompt + self.turn_no_input_format.format(
instruction=instruction
)
if output:
res = f"{res}{output}"
yield res

View File

@@ -62,13 +62,6 @@ def normalize_config(cfg):
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
if cfg.bf16 or cfg.bfloat16:
cfg.torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
cfg.torch_dtype = torch.float16
else:
cfg.torch_dtype = torch.float32
log_gpu_memory_usage(LOG, "baseline", cfg.device)

View File

@@ -41,7 +41,6 @@ from axolotl.prompters import (
ShareGPTPrompter,
SummarizeTLDRPrompter,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.trainer import (
calculate_total_num_steps,
@@ -161,15 +160,8 @@ def load_tokenized_prepared_datasets(
split=None,
)
elif local_path.is_file():
ds_type = "json"
if d.ds_type:
ds_type = d.ds_type
elif ".parquet" in d.path:
ds_type = "parquet"
elif ".arrow" in d.path:
ds_type = "arrow"
ds = load_dataset(
ds_type,
"json",
name=d.name,
data_files=d.path,
streaming=False,
@@ -206,27 +198,13 @@ def load_tokenized_prepared_datasets(
)
else:
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
d_base_type = d_prompt_style = None
d_type = d.type
if isinstance(d_type, str):
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds:
ds = ds["train"]
if (
"input_ids" in ds.features
and "attention_mask" in ds.features
and "labels" in ds.features
):
# dataset is already tokenized, just drop it straight in
datasets.append(ds)
elif isinstance(d.type, DictDefault):
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif ds_strategy := load(d.type, tokenizer, cfg, d):
if ds_strategy := load(d.type, tokenizer, cfg):
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "alpaca":

View File

@@ -11,7 +11,6 @@ import bitsandbytes as bnb
import torch
import transformers
from optimum.bettertransformer import BetterTransformer
from peft.tuners.lora import LoraLayer
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
@@ -147,6 +146,12 @@ def load_model(
LOG.info("patching _expand_mask")
hijack_expand_mask()
if cfg.bf16 or cfg.bfloat16:
torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
try:
if cfg.gptq:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -178,7 +183,7 @@ def load_model(
load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=cfg.torch_dtype,
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
@@ -237,7 +242,7 @@ def load_model(
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
torch_dtype=torch_dtype,
**model_kwargs,
)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -272,7 +277,7 @@ def load_model(
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -303,7 +308,7 @@ def load_model(
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -317,7 +322,7 @@ def load_model(
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -351,6 +356,16 @@ def load_model(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if cfg.flash_attention and cfg.is_llama_derived_model:
for name, module in model.named_modules():
if "norm" in name:
module.to(torch_dtype)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(torch_dtype)
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit:
@@ -423,7 +438,7 @@ def load_llama_adapter(model, cfg):
)
if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - llama_adapter")
LOG.info("Loading pretained LORA")
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
@@ -485,7 +500,6 @@ def load_lora(model, cfg):
)
if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - LoRA")
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
@@ -494,22 +508,6 @@ def load_lora(model, cfg):
else:
model = get_peft_model(model, lora_config)
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
module = module.to(cfg.torch_dtype)
if "norm" in name:
module = module.to(torch.float32)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module = module.to(cfg.torch_dtype)
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if cfg.flash_attention and cfg.is_llama_derived_model:
for name, module in model.named_modules():
if "norm" in name:
module = module.to(cfg.torch_dtype)
model.print_trainable_parameters()
return model, lora_config

View File

@@ -19,10 +19,7 @@ from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import (
SequentialDistributedSampler,
get_parameter_names,
)
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import (
GPUStatsCallback,
@@ -174,18 +171,6 @@ class AxolotlTrainer(Trainer):
)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return SequentialDistributedSampler(
eval_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
batch_size=self.args.per_device_eval_batch_size,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
@@ -210,7 +195,6 @@ class AxolotlTrainer(Trainer):
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
return self.accelerator.prepare(
MultipackDistributedDataloader(
@@ -284,15 +268,15 @@ def disable_datasets_caching():
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
if eval_dataset:
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
add_position_ids, num_proc=os.cpu_count()
)
if eval_dataset:
eval_dataset = eval_dataset.map(add_position_ids, num_proc=os.cpu_count())
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
add_position_ids, num_proc=os.cpu_count()
)
return train_dataset, eval_dataset
@@ -371,16 +355,10 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_offload_params:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
if cfg.fsdp_config.fsdp_sync_module_states:
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
if cfg.fsdp_config.fsdp_state_dict_type:
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
os.environ[
"FSDP_TRANSFORMER_CLS_TO_WRAP"
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
@@ -477,13 +455,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
# we have an eval set, but no steps defined, use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
else:
training_arguments_kwargs["save_strategy"] = (
"steps" if cfg.save_steps else "epoch"
)
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
@@ -495,6 +466,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
save_strategy="steps" if cfg.save_steps else "epoch",
save_steps=cfg.save_steps,
output_dir=cfg.output_dir,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,

File diff suppressed because one or more lines are too long