Compare commits
23 Commits
fsdp-defau
...
benchmark-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3de28942c | ||
|
|
45848a9285 | ||
|
|
d6cea18034 | ||
|
|
606846e0a5 | ||
|
|
a6c9223114 | ||
|
|
8b16ecd448 | ||
|
|
f5db88a10d | ||
|
|
99d844f215 | ||
|
|
aefd4d74fa | ||
|
|
24b0e93235 | ||
|
|
2455254b92 | ||
|
|
918e040601 | ||
|
|
ef062d8fcb | ||
|
|
d4c8b66f3d | ||
|
|
64e9824d3e | ||
|
|
1134654c98 | ||
|
|
2fc756c289 | ||
|
|
943b84c490 | ||
|
|
6f166464d8 | ||
|
|
e3b07402a7 | ||
|
|
8d3c8a3eab | ||
|
|
c30120e684 | ||
|
|
9aed60fa54 |
20
README.md
20
README.md
@@ -328,15 +328,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
name: enron_emails
|
name: enron_emails
|
||||||
type: completion # format from earlier
|
type: completion # format from earlier
|
||||||
|
|
||||||
# huggingface repo with multiple named configurations/subsets
|
|
||||||
datasets:
|
|
||||||
- path: bigcode/commitpackft
|
|
||||||
name:
|
|
||||||
- ruby
|
|
||||||
- python
|
|
||||||
- typescript
|
|
||||||
type: ... # unimplemented custom format
|
|
||||||
|
|
||||||
# local
|
# local
|
||||||
datasets:
|
datasets:
|
||||||
- path: data.jsonl # or json
|
- path: data.jsonl # or json
|
||||||
@@ -416,10 +407,6 @@ fp16: true
|
|||||||
# Use CUDA tf32
|
# Use CUDA tf32
|
||||||
tf32: true # require >=ampere
|
tf32: true # require >=ampere
|
||||||
|
|
||||||
# No AMP (automatic mixed precision)
|
|
||||||
bfloat16: true # require >=ampere
|
|
||||||
float16: true
|
|
||||||
|
|
||||||
# a list of one or more datasets to finetune the model with
|
# a list of one or more datasets to finetune the model with
|
||||||
datasets:
|
datasets:
|
||||||
# hf dataset repo | "json" for local dataset, make sure to fill data_files
|
# hf dataset repo | "json" for local dataset, make sure to fill data_files
|
||||||
@@ -472,9 +459,6 @@ dataset_shard_idx:
|
|||||||
# the maximum length of an input to train with, this should typically be less than 2048
|
# the maximum length of an input to train with, this should typically be less than 2048
|
||||||
# as most models have a token/context limit of 2048
|
# as most models have a token/context limit of 2048
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
# pad inputs so each step uses constant sized buffers
|
|
||||||
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
|
||||||
pad_to_sequence_len:
|
|
||||||
# max sequence length to concatenate training samples together up to
|
# max sequence length to concatenate training samples together up to
|
||||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||||
# FutureWarning: This will soon be DEPRECATED
|
# FutureWarning: This will soon be DEPRECATED
|
||||||
@@ -626,6 +610,9 @@ deepspeed:
|
|||||||
# Path to torch distx for optim 'adamw_anyprecision'
|
# Path to torch distx for optim 'adamw_anyprecision'
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
|
|
||||||
|
# Set padding for data collator to 'longest'
|
||||||
|
collator_pad_to_longest:
|
||||||
|
|
||||||
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
|
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
|
||||||
pretraining_dataset:
|
pretraining_dataset:
|
||||||
|
|
||||||
@@ -665,7 +652,6 @@ fsdp:
|
|||||||
fsdp_config:
|
fsdp_config:
|
||||||
fsdp_offload_params: true
|
fsdp_offload_params: true
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_sync_module_states: true
|
|
||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ RUN apt-get update && \
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main"
|
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main"
|
||||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN cd axolotl && \
|
RUN cd axolotl && \
|
||||||
|
|||||||
@@ -47,3 +47,4 @@ local_rank:
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
fsdp:
|
fsdp:
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
|
collator_pad_to_longest: true
|
||||||
|
|||||||
@@ -25,4 +25,3 @@ rouge-score==0.1.2
|
|||||||
scipy
|
scipy
|
||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
art
|
|
||||||
|
|||||||
@@ -6,17 +6,14 @@ import os
|
|||||||
import random
|
import random
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
from art import text2art
|
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
@@ -25,7 +22,7 @@ from axolotl.utils.config import normalize_config, validate_config
|
|||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.models import load_model, load_model_config, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||||
@@ -40,26 +37,16 @@ LOG = logging.getLogger("axolotl.scripts")
|
|||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
def print_axolotl_text_art():
|
||||||
class TrainerCliArgs:
|
ascii_art = """
|
||||||
"""
|
dP dP dP
|
||||||
dataclass representing the various non-training arguments
|
88 88 88
|
||||||
"""
|
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
|
||||||
|
88' `88 `8bd8' 88' `88 88 88' `88 88 88
|
||||||
|
88. .88 .d88b. 88. .88 88 88. .88 88 88
|
||||||
|
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
|
||||||
|
"""
|
||||||
|
|
||||||
debug: bool = field(default=False)
|
|
||||||
inference: bool = field(default=False)
|
|
||||||
merge_lora: bool = field(default=False)
|
|
||||||
prepare_ds_only: bool = field(default=False)
|
|
||||||
prompter: Optional[str] = field(default=None)
|
|
||||||
shard: bool = field(default=False)
|
|
||||||
|
|
||||||
|
|
||||||
def print_axolotl_text_art(suffix=None):
|
|
||||||
font = "nancyj"
|
|
||||||
ascii_text = " axolotl"
|
|
||||||
if suffix:
|
|
||||||
ascii_text += f" x {suffix}"
|
|
||||||
ascii_art = text2art(" axolotl", font=font)
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
print(ascii_art)
|
print(ascii_art)
|
||||||
|
|
||||||
@@ -74,8 +61,6 @@ def get_multi_line_input() -> Optional[str]:
|
|||||||
|
|
||||||
|
|
||||||
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
|
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
|
||||||
if prompter == "None":
|
|
||||||
prompter = None
|
|
||||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||||
|
|
||||||
for token, symbol in default_tokens.items():
|
for token, symbol in default_tokens.items():
|
||||||
@@ -150,10 +135,6 @@ def choose_config(path: Path):
|
|||||||
"No YAML config files found in the specified directory. Are you using a .yml extension?"
|
"No YAML config files found in the specified directory. Are you using a .yml extension?"
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(yaml_files) == 1:
|
|
||||||
print(f"Using default YAML file '{yaml_files[0]}'")
|
|
||||||
return yaml_files[0]
|
|
||||||
|
|
||||||
print("Choose a YAML file:")
|
print("Choose a YAML file:")
|
||||||
for idx, file in enumerate(yaml_files):
|
for idx, file in enumerate(yaml_files):
|
||||||
print(f"{idx + 1}. {file}")
|
print(f"{idx + 1}. {file}")
|
||||||
@@ -177,20 +158,45 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
|
|||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
*,
|
config: Path = Path("configs/"),
|
||||||
cfg: DictDefault,
|
prepare_ds_only: bool = False,
|
||||||
cli_args: TrainerCliArgs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
print_axolotl_text_art()
|
||||||
|
if Path(config).is_dir():
|
||||||
|
config = choose_config(config)
|
||||||
|
|
||||||
|
# load the config from the yaml file
|
||||||
|
with open(config, encoding="utf-8") as file:
|
||||||
|
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||||
|
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||||
|
# then overwrite the value
|
||||||
|
cfg_keys = cfg.keys()
|
||||||
|
for k, _ in kwargs.items():
|
||||||
|
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||||
|
if k in cfg_keys or not cfg.strict:
|
||||||
|
# handle booleans
|
||||||
|
if isinstance(cfg[k], bool):
|
||||||
|
cfg[k] = bool(kwargs[k])
|
||||||
|
else:
|
||||||
|
cfg[k] = kwargs[k]
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
normalize_config(cfg)
|
||||||
|
|
||||||
|
setup_wandb_env_vars(cfg)
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
if not (
|
if (
|
||||||
cli_args.shard or cli_args.merge_lora or cli_args.inference
|
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
||||||
): # don't need to load dataset for these
|
): # don't need to load dataset for these
|
||||||
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
||||||
|
|
||||||
if cli_args.debug or cfg.debug:
|
if cfg.debug or "debug" in kwargs:
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
check_dataset_labels(
|
check_dataset_labels(
|
||||||
train_dataset.select(
|
train_dataset.select(
|
||||||
@@ -199,17 +205,17 @@ def train(
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cli_args.prepare_ds_only:
|
if prepare_ds_only:
|
||||||
LOG.info("Finished preparing dataset. Exiting...")
|
LOG.info("Finished preparing dataset. Exiting...")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
LOG.info("loading model and (optionally) peft_config...")
|
LOG.info("loading model and (optionally) peft_config...")
|
||||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
model, peft_config = load_model(cfg, tokenizer)
|
||||||
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
if cli_args.merge_lora and cfg.adapter is not None:
|
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||||
LOG.info("running merge of LoRA with base model")
|
LOG.info("running merge of LoRA with base model")
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
model.to(dtype=torch.float16)
|
model.to(dtype=torch.float16)
|
||||||
@@ -223,13 +229,18 @@ def train(
|
|||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
return
|
return
|
||||||
|
|
||||||
if cli_args.inference:
|
if cfg.inference:
|
||||||
LOG.debug("Running inference on model")
|
LOG.info("calling do_inference function")
|
||||||
do_inference(cfg, model, tokenizer, prompter=cli_args.prompter)
|
prompter: Optional[str] = "AlpacaPrompter"
|
||||||
|
if "prompter" in kwargs:
|
||||||
|
if kwargs["prompter"] == "None":
|
||||||
|
prompter = None
|
||||||
|
else:
|
||||||
|
prompter = kwargs["prompter"]
|
||||||
|
do_inference(cfg, model, tokenizer, prompter=prompter)
|
||||||
return
|
return
|
||||||
|
|
||||||
if cli_args.shard:
|
if "shard" in kwargs:
|
||||||
LOG.debug("Re-saving model w/ sharding")
|
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -311,51 +322,5 @@ def train(
|
|||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
|
|
||||||
def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|
||||||
if Path(config).is_dir():
|
|
||||||
config = choose_config(config)
|
|
||||||
|
|
||||||
# load the config from the yaml file
|
|
||||||
with open(config, encoding="utf-8") as file:
|
|
||||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
|
||||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
|
||||||
# then overwrite the value
|
|
||||||
cfg_keys = cfg.keys()
|
|
||||||
for k, _ in kwargs.items():
|
|
||||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
|
||||||
if k in cfg_keys or not cfg.strict:
|
|
||||||
# handle booleans
|
|
||||||
if isinstance(cfg[k], bool):
|
|
||||||
cfg[k] = bool(kwargs[k])
|
|
||||||
else:
|
|
||||||
cfg[k] = kwargs[k]
|
|
||||||
|
|
||||||
model_config = load_model_config(cfg)
|
|
||||||
|
|
||||||
# figure out if the model is llama
|
|
||||||
cfg.is_llama_derived_model = (
|
|
||||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
|
||||||
or cfg.is_llama_derived_model
|
|
||||||
or "llama" in cfg.base_model
|
|
||||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
|
||||||
)
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
normalize_config(cfg)
|
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
def do_train(config: Path = Path("examples/"), **kwargs):
|
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
|
||||||
return_remaining_strings=True
|
|
||||||
)
|
|
||||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
fire.Fire(do_train)
|
fire.Fire(train)
|
||||||
|
|||||||
@@ -1,45 +0,0 @@
|
|||||||
"""
|
|
||||||
Monkeypatch to fix fsdp set state when no previous state was set
|
|
||||||
"""
|
|
||||||
|
|
||||||
import contextlib
|
|
||||||
from typing import Generator, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.distributed.fsdp.api import (
|
|
||||||
OptimStateDictConfig,
|
|
||||||
StateDictConfig,
|
|
||||||
StateDictType,
|
|
||||||
)
|
|
||||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def state_dict_type_patch(
|
|
||||||
module: nn.Module,
|
|
||||||
state_dict_type: StateDictType,
|
|
||||||
state_dict_config: Optional[StateDictConfig] = None,
|
|
||||||
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
|
|
||||||
) -> Generator:
|
|
||||||
prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
|
|
||||||
module,
|
|
||||||
state_dict_type,
|
|
||||||
state_dict_config,
|
|
||||||
optim_state_dict_config,
|
|
||||||
)
|
|
||||||
yield
|
|
||||||
if prev_state_dict_settings.state_dict_type:
|
|
||||||
FullyShardedDataParallel.set_state_dict_type(
|
|
||||||
module,
|
|
||||||
prev_state_dict_settings.state_dict_type,
|
|
||||||
prev_state_dict_settings.state_dict_config,
|
|
||||||
prev_state_dict_settings.optim_state_dict_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def replace_fsdp_state_dict_type():
|
|
||||||
torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.state_dict_type = (
|
|
||||||
state_dict_type_patch
|
|
||||||
)
|
|
||||||
@@ -152,16 +152,6 @@ def validate_config(cfg):
|
|||||||
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||||
raise ValueError("FSDP is not supported for falcon models")
|
raise ValueError("FSDP is not supported for falcon models")
|
||||||
|
|
||||||
if (
|
|
||||||
cfg.fsdp
|
|
||||||
and cfg.fsdp_config
|
|
||||||
and cfg.fsdp_config.fsdp_state_dict_type
|
|
||||||
and not cfg.fsdp_config.fsdp_sync_module_states
|
|
||||||
):
|
|
||||||
LOG.warning(
|
|
||||||
"We recommend setting fsdp_config.fsdp_sync_module_states to `true`"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.base_model and "mpt" in cfg.base_model.lower()
|
cfg.base_model and "mpt" in cfg.base_model.lower()
|
||||||
) and cfg.gradient_checkpointing:
|
) and cfg.gradient_checkpointing:
|
||||||
|
|||||||
@@ -134,17 +134,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
seed = 42
|
seed = 42
|
||||||
|
|
||||||
datasets = []
|
datasets = []
|
||||||
|
|
||||||
def for_d_in_datasets(dataset_configs):
|
|
||||||
for dataset in dataset_configs:
|
|
||||||
if dataset.name and isinstance(dataset.name, list):
|
|
||||||
for name in dataset.name:
|
|
||||||
yield DictDefault({**dataset, "name": name})
|
|
||||||
else:
|
|
||||||
yield dataset
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
for d in for_d_in_datasets(cfg.datasets):
|
for d in cfg.datasets:
|
||||||
ds: Union[Dataset, DatasetDict] = None
|
ds: Union[Dataset, DatasetDict] = None
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -5,13 +5,12 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple # noqa: F401
|
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from peft import PeftConfig
|
|
||||||
from transformers import ( # noqa: F401
|
from transformers import ( # noqa: F401
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -24,17 +23,13 @@ from transformers import ( # noqa: F401
|
|||||||
|
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from peft import PeftConfig # noqa: F401
|
||||||
|
|
||||||
def load_model_config(cfg):
|
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||||
model_config_name = cfg.base_model_config or cfg.base_model
|
|
||||||
trust_remote_code: bool = False or cfg.trust_remote_code
|
|
||||||
return AutoConfig.from_pretrained(
|
|
||||||
model_config_name, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
@@ -91,10 +86,8 @@ def load_tokenizer(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
cfg: DictDefault,
|
cfg, tokenizer
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
inference: bool = False,
|
|
||||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
|
||||||
"""
|
"""
|
||||||
Load a model for a given configuration and tokenizer.
|
Load a model for a given configuration and tokenizer.
|
||||||
"""
|
"""
|
||||||
@@ -104,9 +97,14 @@ def load_model(
|
|||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
|
cfg.is_llama_derived_model = (
|
||||||
|
"llama" in base_model
|
||||||
|
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||||
|
or cfg.is_llama_derived_model
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||||
if cfg.device not in ["mps", "cpu"] and not inference:
|
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
replace_llama_attn_with_flash_attn,
|
replace_llama_attn_with_flash_attn,
|
||||||
)
|
)
|
||||||
@@ -148,7 +146,7 @@ def load_model(
|
|||||||
if (
|
if (
|
||||||
cfg.is_llama_derived_model
|
cfg.is_llama_derived_model
|
||||||
and (cfg.max_packed_sequence_len or cfg.sample_packing)
|
and (cfg.max_packed_sequence_len or cfg.sample_packing)
|
||||||
and not inference
|
and not cfg.inference
|
||||||
):
|
):
|
||||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||||
|
|
||||||
@@ -426,15 +424,15 @@ def load_model(
|
|||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(model, cfg, adapter):
|
||||||
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
|
|
||||||
if adapter is None:
|
if adapter is None:
|
||||||
return model, None
|
return model, None
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
model.enable_input_require_grads()
|
model.enable_input_require_grads()
|
||||||
if adapter in ["lora", "qlora"]:
|
if adapter in ["lora", "qlora"]:
|
||||||
return load_lora(model, cfg, inference=inference)
|
return load_lora(model, cfg)
|
||||||
if adapter == "llama-adapter":
|
if adapter == "llama-adapter":
|
||||||
return load_llama_adapter(model, cfg)
|
return load_llama_adapter(model, cfg)
|
||||||
|
|
||||||
@@ -466,8 +464,12 @@ def load_llama_adapter(model, cfg):
|
|||||||
return model, peft_config
|
return model, peft_config
|
||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(bits, model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
cls = (
|
||||||
|
bnb.nn.Linear4bit
|
||||||
|
if bits == 4
|
||||||
|
else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
|
||||||
|
)
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, cls):
|
if isinstance(module, cls):
|
||||||
@@ -480,15 +482,21 @@ def find_all_linear_names(model):
|
|||||||
return list(lora_module_names)
|
return list(lora_module_names)
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model, cfg, inference=False):
|
def load_lora(model, cfg):
|
||||||
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
|
|
||||||
from peft import LoraConfig, PeftModel, get_peft_model
|
from peft import LoraConfig, PeftModel, get_peft_model
|
||||||
|
|
||||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||||
|
|
||||||
if cfg.lora_target_linear:
|
if cfg.lora_target_linear:
|
||||||
linear_names = find_all_linear_names(model)
|
bits = None
|
||||||
|
if cfg.load_in_4bit:
|
||||||
|
bits = 4
|
||||||
|
elif cfg.load_in_8bit:
|
||||||
|
bits = 8
|
||||||
|
|
||||||
|
linear_names = find_all_linear_names(bits, model)
|
||||||
LOG.info(f"found linear modules: {repr(linear_names)}")
|
LOG.info(f"found linear modules: {repr(linear_names)}")
|
||||||
lora_target_modules = list(set(lora_target_modules + linear_names))
|
lora_target_modules = list(set(lora_target_modules + linear_names))
|
||||||
|
|
||||||
@@ -508,7 +516,7 @@ def load_lora(model, cfg, inference=False):
|
|||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.lora_model_dir,
|
cfg.lora_model_dir,
|
||||||
is_trainable=(not inference),
|
is_trainable=not cfg.inference,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
|
|||||||
@@ -471,9 +471,6 @@ def setup_fsdp_envs(cfg):
|
|||||||
os.environ[
|
os.environ[
|
||||||
"FSDP_TRANSFORMER_CLS_TO_WRAP"
|
"FSDP_TRANSFORMER_CLS_TO_WRAP"
|
||||||
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||||
from axolotl.monkeypatch.fsdp import replace_fsdp_state_dict_type
|
|
||||||
|
|
||||||
replace_fsdp_state_dict_type()
|
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
@@ -650,12 +647,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
callbacks.append(SaveBetterTransformerModelCallback)
|
callbacks.append(SaveBetterTransformerModelCallback)
|
||||||
|
|
||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
"padding": True, # True/"longest" is the default
|
"padding": True,
|
||||||
}
|
}
|
||||||
if cfg.pad_to_sequence_len:
|
if cfg.collator_pad_to_longest:
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
|
data_collator_kwargs["padding"] = "longest"
|
||||||
cfg.sequence_len / 64
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
|
|||||||
Reference in New Issue
Block a user