Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
83d904a27d fix the context manager call
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-09-03 22:49:28 -04:00
Wing Lian
5e4a760ad8 start to swap out for accelerate partial state 2023-09-03 22:41:29 -04:00
18 changed files with 251 additions and 386 deletions

View File

@@ -23,6 +23,11 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.1
axolotl_extras: gptq
runs-on: self-hosted runs-on: self-hosted
steps: steps:
- name: Checkout - name: Checkout
@@ -68,6 +73,11 @@ jobs:
pytorch: 2.0.1 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.1
axolotl_extras: gptq
runs-on: self-hosted runs-on: self-hosted
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -24,7 +24,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip install -e . pip install -e .[peft]
pip install -r requirements-tests.txt pip install -r requirements-tests.txt
- name: Run tests - name: Run tests

View File

@@ -11,13 +11,14 @@ RUN apt-get update && \
WORKDIR /workspace WORKDIR /workspace
RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main"
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \ RUN cd axolotl && \
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,gptq,$AXOLOTL_EXTRAS]; \ pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
else \ else \
pip install -e .[flash-attn,gptq]; \ pip install -e .[flash-attn]; \
fi fi
# fix so that git fetch/pull from remote works # fix so that git fetch/pull from remote works

View File

@@ -0,0 +1,8 @@
# LLaMa 7B using LoRA
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
```shell
accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
```

View File

@@ -0,0 +1,63 @@
base_model: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
base_model_config: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code:
load_in_8bit: true
gptq: true
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./llama-7b-lora-int4
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0000002
train_on_inputs: false
group_by_length: false
fp16: true
bf16: false
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 5
xformers_attention:
flash_attention:
gradient_checkpointing: true
gptq_groupsize: 128
gptq_model_v1: false
warmup_steps: 20
eval_steps: 110
save_steps: 660
debug:
deepspeed:
weight_decay: 0.0001
fsdp:
fsdp_config:
tokens:
pad_token: "<pad>"
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,76 +0,0 @@
base_model: TheBloke/Llama-2-7B-GPTQ
base_model_config: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true
gptq_bits: 4
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
tokenizer_use_fast: true
tokenizer_legacy: true
load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: lora
lora_model_dir:
sequence_len: 4096
sample_packing:
lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- k_proj
- o_proj
- q_proj
- v_proj
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./model-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.00001
max_grad_norm: 1.0
torchdistx_path:
lr_scheduler: cosine
lr_quadratic_warmup: true
learning_rate: 0.000017
train_on_inputs: false
group_by_length: false
bf16: false
fp16: false
float16: true
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
sdp_attention:
flash_optimum:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 100
eval_steps:
save_steps:
debug:
deepspeed:
weight_decay: 0.1
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,7 +1,3 @@
--extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
torch==2.0.1
auto-gptq
packaging packaging
peft @ git+https://github.com/huggingface/peft.git peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git transformers @ git+https://github.com/huggingface/transformers.git

View File

@@ -24,7 +24,7 @@ from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer from axolotl.utils.models import load_model_config, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb import setup_wandb_env_vars from axolotl.utils.wandb import setup_wandb_env_vars
@@ -216,6 +216,15 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
else: else:
cfg[k] = kwargs[k] cfg[k] = kwargs[k]
model_config = load_model_config(cfg)
# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
validate_config(cfg) validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)

View File

@@ -2,27 +2,15 @@
from setuptools import find_packages, setup from setuptools import find_packages, setup
install_requires = []
def parse_requirements(): with open("./requirements.txt", encoding="utf-8") as requirements_file:
_install_requires = [] # don't include peft yet until we check the int4
_dependency_links = [] # need to manually install peft for now...
with open("./requirements.txt", encoding="utf-8") as requirements_file: reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
lines = [ reqs = [r for r in reqs if "flash-attn" not in r]
r.strip() for r in requirements_file.readlines() if "auto-gptq" not in r reqs = [r for r in reqs if r and r[0] != "#"]
] for r in reqs:
for line in lines: install_requires.append(r)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif "flash-attn" not in line and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
return _install_requires, _dependency_links
install_requires, dependency_links = parse_requirements()
setup( setup(
name="axolotl", name="axolotl",
@@ -31,10 +19,12 @@ setup(
package_dir={"": "src"}, package_dir={"": "src"},
packages=find_packages(), packages=find_packages(),
install_requires=install_requires, install_requires=install_requires,
dependency_links=dependency_links,
extras_require={ extras_require={
"gptq": [ "gptq": [
"auto-gptq", "alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
],
"gptq_triton": [
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
], ],
"flash-attn": [ "flash-attn": [
"flash-attn==2.0.8", "flash-attn==2.0.8",
@@ -42,5 +32,8 @@ setup(
"extras": [ "extras": [
"deepspeed", "deepspeed",
], ],
"peft": [
"peft @ git+https://github.com/huggingface/peft.git",
],
}, },
) )

View File

@@ -1,144 +0,0 @@
import logging
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Union
from datasets import Dataset as Dataset_ds
from datasets import DatasetDict, IterableDataset, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
logger = logging.getLogger("axolotl")
class DsType(Enum):
JSON = "json"
ARROW = "arrow"
PARQUET = "parquet"
@dataclass
class DatasetConfiguration:
path: str
type: str
name: Optional[str] = field(
default=None,
metadata={"help": "the name of the dataset configuration to load."},
)
ds_type: Optional[DsType] = None
data_files: Optional[Union[str, List[str]]] = None
shards: Optional[int] = None
test_size: Optional[float] = None
@staticmethod
def from_dict(d: Dict[str, Any]) -> Generator["DatasetConfiguration", None, None]:
if "name" in d and isinstance(d["name"], list):
name = d.pop("name")
for n in name:
yield DatasetConfiguration(
**d,
name=n,
)
def load_dataset_from_local(config: DatasetConfiguration) -> Optional[Dataset_ds]:
local_path = Path(config.path)
if not local_path.exists():
return None
ds = None
if local_path.is_dir():
if config.ds_type:
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_from_disk(config.path)
else:
ds = load_dataset(
config.path,
name=config.name,
data_files=config.data_files,
streaming=False,
split=None,
)
elif local_path.is_file():
ds_type = "json"
if config.ds_type:
ds_type = config.ds_type.value
elif "parquet" in config.path:
ds_type = "parquet"
elif "arrow" in config.path:
ds_type = "arrow"
ds = load_dataset(
ds_type,
name=config.name,
data_files=config.path,
streaming=False,
split=None, # is this correct?
)
if not ds:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
return ds
# TODO should this be a DatasetDict?
class Dataset(Dataset_ds):
_config: DatasetConfiguration
def __init__(self, *args, config: DatasetConfiguration = None, **kwargs):
self._config = config
super().__init__(*args, **kwargs)
@staticmethod
def from_config(
config: DatasetConfiguration,
token: bool = False,
default_test_size: float = 0.1,
):
ds = load_dataset_from_local(config)
if not ds:
try:
ds = load_dataset(
config.path,
name=config.name,
data_files=config.data_files,
token=token,
)
except FileNotFoundError:
pass
if not ds:
fp = hf_hub_download(
repo_id=config.path,
repo_type="dataset",
filename=config.data_files,
token=token,
)
ds = load_dataset(
"json", name=config.name, data_files=fp, streaming=False, split=None
)
if not ds:
raise ValueError("unhandled dataset load")
test_size = config.test_size if config.test_size else default_test_size
# determine if the dataset is pre-tokenized
check_ds = ds["train"] if isinstance(ds, DatasetDict) and "train" in ds else ds
is_ds_tokenized = False
if "input_ids" in check_ds.features:
is_ds_tokenized = True
if "attention_mask" not in check_ds.features:
logger.warning("`attention_mask` missing from pre-tokenized dataset")
if "labels" not in check_ds.features:
logger.warning("`labels` missing from pre-tokenized dataset")
if test_size and (not isinstance(ds, DatasetDict) or "test" not in ds):
ds.train_test_split(test_size=test_size, shuffle=False)
pass
class DatasetCollection:
datasets: List[Dataset] = []
def __init__(self, datasets: Union[Dataset, List[Dataset]]):
self.datasets = datasets if isinstance(datasets, list) else [datasets]
def __iter__(self):
for ds in self.datasets:
for d in ds:
yield d

View File

@@ -2,9 +2,7 @@
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
import logging
import warnings import warnings
from functools import partial
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
@@ -35,9 +33,6 @@ except ImportError:
) )
LOG = logging.getLogger("axolotl")
def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False): def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask _prepare_decoder_attention_mask
@@ -49,34 +44,6 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
llama_model_forward llama_model_forward
) )
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
try:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask # requires the attention mask to be the same as the key_padding_mask

View File

@@ -309,6 +309,10 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
) )
def build_prompt(self, source) -> Generator[str, None, None]: def build_prompt(self, source) -> Generator[str, None, None]:
# ignore the system prompt if provided
if source[0]["from"] == "system":
source.pop(0)
if len(source) < 2: if len(source) < 2:
# If there isn't a back and forth conversation, ignore it # If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations # also happens on the data splitting leaving empty conversations
@@ -317,12 +321,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
) )
conv = self._conversation.copy() conv = self._conversation.copy()
# Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system":
conv.system = source[0]["value"]
source.pop(0)
roles = {"human": conv.roles[0], "gpt": conv.roles[1]} roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
try: try:

View File

@@ -11,6 +11,7 @@ import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate.state import PartialState
from datasets import load_dataset from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm from tqdm import tqdm
@@ -24,12 +25,9 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import ( from axolotl.utils.distributed import (
barrier,
gather_scalar_from_all_ranks, gather_scalar_from_all_ranks,
get_world_size, get_world_size,
is_distributed,
is_main_process, is_main_process,
zero_first,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -37,6 +35,7 @@ if TYPE_CHECKING:
LOG = logging.getLogger("axolotl.callbacks") LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100 IGNORE_INDEX = -100
dist_state = PartialState()
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -211,7 +210,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
"subject": example["subject"], "subject": example["subject"],
} }
with zero_first(is_main_process()): with dist_state.main_process_first():
bench_dataset = bench_dataset.map(tokenize_evals) bench_dataset = bench_dataset.map(tokenize_evals)
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx) bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
@@ -259,7 +258,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p) bench_names[s]["preds"].append(p)
bench_names[s]["refs"].append(r) bench_names[s]["refs"].append(r)
barrier() dist_state.wait_for_everyone()
local_bench_names = bench_names local_bench_names = bench_names
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())] gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
# Gather results from all GPUs to GPU 0 # Gather results from all GPUs to GPU 0
@@ -271,13 +270,10 @@ def bench_eval_callback_factory(trainer, tokenizer):
lambda: len(data_loader), get_world_size() lambda: len(data_loader), get_world_size()
) )
if is_distributed() and not is_main_process(): if not is_main_process():
dist.gather_object(local_bench_names, dst=0) dist.gather_object(local_bench_names, dst=0)
else: else:
if is_distributed(): dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
else:
gathered_bench_names = [local_bench_names]
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks) bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
results = {f"{bench_split}_bench_loss": bench_loss} results = {f"{bench_split}_bench_loss": bench_loss}

View File

@@ -6,7 +6,6 @@ import os
import torch import torch
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.models import load_model_config
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -70,16 +69,6 @@ def normalize_config(cfg):
else: else:
cfg.torch_dtype = torch.float32 cfg.torch_dtype = torch.float32
model_config = load_model_config(cfg)
# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
log_gpu_memory_usage(LOG, "baseline", cfg.device) log_gpu_memory_usage(LOG, "baseline", cfg.device)
@@ -108,7 +97,9 @@ def validate_config(cfg):
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.", "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
) )
if cfg.load_4bit: if cfg.load_4bit:
raise ValueError("cfg.load_4bit parameter has been deprecated") raise ValueError(
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
)
if cfg.adapter == "qlora": if cfg.adapter == "qlora":
if cfg.merge_lora: if cfg.merge_lora:

View File

@@ -7,6 +7,7 @@ from pathlib import Path
from typing import Tuple, Union from typing import Tuple, Union
import torch import torch
from accelerate.state import PartialState
from datasets import ( from datasets import (
Dataset, Dataset,
DatasetDict, DatasetDict,
@@ -42,7 +43,6 @@ from axolotl.prompters import (
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
calculate_total_num_steps, calculate_total_num_steps,
process_datasets_for_packing, process_datasets_for_packing,
@@ -50,11 +50,12 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
state = PartialState()
def prepare_dataset(cfg, tokenizer): def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
with zero_first(is_main_process()): with state.main_process_first():
train_dataset, eval_dataset = load_prepare_datasets( train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )
@@ -69,7 +70,7 @@ def prepare_dataset(cfg, tokenizer):
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
eval_dataset = None eval_dataset = None
with zero_first(is_main_process()): with state.main_process_first():
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset cfg, train_dataset, eval_dataset
) )
@@ -507,7 +508,7 @@ def load_prepare_datasets(
to_hash_test.encode(), usedforsecurity=False to_hash_test.encode(), usedforsecurity=False
).hexdigest() ).hexdigest()
with zero_first(is_main_process()): with state.main_process_first():
dataset = dataset.train_test_split( dataset = dataset.train_test_split(
test_size=cfg.val_set_size, test_size=cfg.val_set_size,
shuffle=False, shuffle=False,

View File

@@ -1,29 +1,27 @@
""" """
utility helpers for distributed checks utility helpers for distributed checks
""" """
import os
from contextlib import contextmanager
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate import Accelerator from accelerate import DistributedType
from accelerate.state import PartialState
from accelerate.utils import wait_for_everyone
accelerate = None # pylint: disable=invalid-name accelerate = None # pylint: disable=invalid-name
state = PartialState()
def load_accelerate():
global accelerate # pylint: disable=global-statement
accelerate = Accelerator()
def is_distributed(): def is_distributed():
""" """
Check if distributed training is initialized. Check if distributed training is initialized.
""" """
global accelerate # pylint: disable=global-statement return state.distributed_type in (
if not accelerate: DistributedType.MULTI_GPU,
accelerate = Accelerator() DistributedType.MULTI_CPU,
return dist.is_available() and dist.is_initialized() DistributedType.DEEPSPEED,
DistributedType.FSDP,
)
def barrier(): def barrier():
@@ -31,34 +29,19 @@ def barrier():
Acts as a barrier to wait for all processes. This ensures that all processes Acts as a barrier to wait for all processes. This ensures that all processes
reach the barrier before proceeding further. reach the barrier before proceeding further.
""" """
if is_distributed(): wait_for_everyone()
dist.barrier()
def is_main_process(): def is_main_process() -> bool:
""" """
Check if the current process is the main process. Check if the current process is the main process.
If not in distributed mode, always return True. If not in distributed mode, always return True.
""" """
if not is_distributed(): return state.is_main_process
return True
return dist.get_rank() == 0
def get_world_size(): def get_world_size() -> int:
return int(os.getenv("WORLD_SIZE", "1")) return state.num_processes
@contextmanager
def zero_first(is_main):
"""
runs the wrapped context so that rank 0 runs first before other ranks
"""
if not is_main: # other ranks wait first
barrier()
yield
if is_main: # then rank 0 waits after it has run the context
barrier()
def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
@@ -74,11 +57,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
- A list of computed values from all ranks if on the gathering rank, otherwise None. - A list of computed values from all ranks if on the gathering rank, otherwise None.
""" """
value_scalar = fn() value_scalar = fn()
if not is_distributed():
return [value_scalar]
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float() value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
if not is_main_process(): if not state.is_main_process:
dist.gather(value_tensor, dst=0) dist.gather(value_tensor, dst=0)
else: else:
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]

View File

@@ -4,19 +4,19 @@
import logging import logging
import math import math
import os import os
from pathlib import Path
from typing import Optional, Tuple # noqa: F401 from typing import Optional, Tuple # noqa: F401
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training from peft import PeftConfig
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
BitsAndBytesConfig, BitsAndBytesConfig,
GPTQConfig,
LlamaConfig, LlamaConfig,
PreTrainedModel, PreTrainedModel,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
@@ -155,17 +155,32 @@ def load_model(
LOG.info("patching _expand_mask") LOG.info("patching _expand_mask")
hijack_expand_mask() hijack_expand_mask()
try:
if cfg.gptq:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
replace_peft_model_with_int4_lora_model,
)
replace_peft_model_with_int4_lora_model()
except Exception as err:
LOG.exception(err)
raise err
if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
):
try:
from peft import prepare_model_for_kbit_training
except ImportError:
# For backward compatibility
from peft import (
prepare_model_for_int8_training as prepare_model_for_kbit_training,
)
model_kwargs = {} model_kwargs = {}
if cfg.model_revision: if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision model_kwargs["revision"] = cfg.model_revision
if cfg.gptq:
model_config = load_model_config(cfg)
if hasattr(model_config, "quantization_config"):
LOG.warning("model config does not contain quantization_config information")
else:
model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config
)
if cfg.adapter == "qlora" and cfg.load_in_4bit: if cfg.adapter == "qlora" and cfg.load_in_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig( model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True, load_in_4bit=True,
@@ -176,7 +191,45 @@ def load_model(
bnb_4bit_quant_type="nf4", bnb_4bit_quant_type="nf4",
) )
try: try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: if cfg.gptq and cfg.is_llama_derived_model:
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from huggingface_hub import snapshot_download
try:
snapshot_download_kwargs = {}
if cfg.base_model_ignore_patterns:
snapshot_download_kwargs[
"ignore_patterns"
] = cfg.base_model_ignore_patterns
cache_model_path = Path(
snapshot_download(base_model, **snapshot_download_kwargs)
)
files = (
list(cache_model_path.glob("*.pt"))
+ list(cache_model_path.glob("*.safetensors"))
+ list(cache_model_path.glob("*.bin"))
)
if len(files) > 0:
model_path = str(files[0])
else:
LOG.warning(
"unable to find a cached model file, this will likely fail..."
)
model_path = str(cache_model_path)
except Exception: # pylint: disable=broad-exception-caught
model_path = cfg.base_model
model, _ = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model,
model_path,
device_map=cfg.device_map,
half=cfg.fp16,
groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
is_v1_model=cfg.gptq_model_v1
if cfg.gptq_model_v1 is not None
else True,
)
load_in_8bit = False
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
config_kwargs = {} config_kwargs = {}
@@ -222,24 +275,15 @@ def load_model(
# ) # )
# model.train() # sets to train instead of eval mode # model.train() # sets to train instead of eval mode
elif model_type and not cfg.trust_remote_code: elif model_type and not cfg.trust_remote_code:
if cfg.gptq: model = getattr(transformers, model_type).from_pretrained(
model = AutoModelForCausalLM.from_pretrained( base_model,
base_model, device_map=cfg.device_map,
device_map=cfg.device_map, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False, torch_dtype=cfg.torch_dtype,
**model_kwargs, trust_remote_code=cfg.trust_remote_code or False,
) **model_kwargs,
else: )
model = getattr(transformers, model_type).from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else: else:
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
base_model, base_model,
@@ -315,12 +359,11 @@ def load_model(
module.to(torch.float32) module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp needs_fa2_dtype = cfg.adapter or cfg.fsdp
if (cfg.adapter == "lora" and load_in_8bit) or ( if not cfg.gptq and (
cfg.adapter == "qlora" and cfg.load_in_4bit (cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
): ):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training( model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing model, use_gradient_checkpointing=cfg.gradient_checkpointing
) )
@@ -342,10 +385,22 @@ def load_model(
if cfg.ddp and not load_in_8bit: if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")
if cfg.gptq:
# Scales to half
LOG.info("Fitting 4bit scales and zeros to half")
for _, module in model.named_modules():
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
type(module)
):
if hasattr(module, "is_v1_model") and module.is_v1_model:
module.zeros = module.zeros.half()
module.scales = module.scales.half()
module.bias = module.bias.half()
if ( if (
torch.cuda.device_count() > 1 torch.cuda.device_count() > 1
and int(os.getenv("WORLD_SIZE", "1")) > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1
and (cfg.load_in_4bit) and (cfg.gptq or cfg.load_in_4bit)
): ):
# llama is PROBABLY model parallelizable, but the default isn't that it is # llama is PROBABLY model parallelizable, but the default isn't that it is
# so let's only set it for the 4bit, see # so let's only set it for the 4bit, see

View File

@@ -514,7 +514,23 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["seed"] = cfg.seed training_arguments_kwargs["seed"] = cfg.seed
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing if cfg.gptq:
from alpaca_lora_4bit.gradient_checkpointing import (
apply_gradient_checkpointing,
)
gradient_checkpointing_ratio = (
cfg.gradient_checkpointing_ratio
if cfg.gradient_checkpointing_ratio
else 1.0
)
apply_gradient_checkpointing(
model, checkpoint_ratio=gradient_checkpointing_ratio
)
else:
training_arguments_kwargs[
"gradient_checkpointing"
] = cfg.gradient_checkpointing
if cfg.fsdp: if cfg.fsdp:
training_arguments_kwargs["fsdp"] = cfg.fsdp training_arguments_kwargs["fsdp"] = cfg.fsdp
if cfg.fsdp_config: if cfg.fsdp_config: