Compare commits

..

23 Commits

Author SHA1 Message Date
Wing Lian
c3de28942c fix for gather across multiple gpus
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-29 06:57:28 -07:00
Wing Lian
45848a9285 gather benchmarks from all ranks 2023-08-28 11:29:59 -04:00
Wing Lian
d6cea18034 improve support for customized dataset for bench evals 2023-08-28 06:03:53 -04:00
Wing Lian
606846e0a5 missing transformers import 2023-08-28 05:43:19 -04:00
Wing Lian
a6c9223114 more fixes 2023-08-28 05:39:13 -04:00
Wing Lian
8b16ecd448 updated dataset 2023-08-28 05:39:13 -04:00
Wing Lian
f5db88a10d fixes 2023-08-28 05:39:13 -04:00
Wing Lian
99d844f215 benchmark callback has its own dataloader and collator 2023-08-28 05:39:13 -04:00
Wing Lian
aefd4d74fa better handling when no subjects 2023-08-28 05:39:13 -04:00
Wing Lian
24b0e93235 dataset handling and aggregate across benchmark 2023-08-28 05:39:13 -04:00
Wing Lian
2455254b92 more fixes 2023-08-28 05:39:13 -04:00
Wing Lian
918e040601 rename mmlu to bench 2023-08-28 05:39:13 -04:00
Wing Lian
ef062d8fcb more fixes 2023-08-28 05:39:13 -04:00
Wing Lian
d4c8b66f3d fix elif and add better messaging 2023-08-28 05:39:13 -04:00
Wing Lian
64e9824d3e fix the data file 2023-08-28 05:39:13 -04:00
Wing Lian
1134654c98 sample benchmarks, ensure we drop long samples 2023-08-28 05:39:13 -04:00
Wing Lian
2fc756c289 fix mmlu evals 2023-08-28 05:39:13 -04:00
Wing Lian
943b84c490 another callback fix for collator max len attribute 2023-08-28 05:39:13 -04:00
Wing Lian
6f166464d8 include metrics in callback 2023-08-28 05:39:13 -04:00
Wing Lian
e3b07402a7 make sure to define all the explicit positional args 2023-08-28 05:39:13 -04:00
Wing Lian
8d3c8a3eab default to mmlu-zs 2023-08-28 05:39:13 -04:00
Wing Lian
c30120e684 use hf dataset for mmlu evals 2023-08-28 05:39:13 -04:00
Wing Lian
9aed60fa54 add mmlu callback 2023-08-28 05:39:12 -04:00
15 changed files with 237 additions and 412 deletions

View File

@@ -163,8 +163,6 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
```
</details>
- Windows: Please use WSL or Docker!
### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
@@ -330,15 +328,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
name: enron_emails
type: completion # format from earlier
# huggingface repo with multiple named configurations/subsets
datasets:
- path: bigcode/commitpackft
name:
- ruby
- python
- typescript
type: ... # unimplemented custom format
# local
datasets:
- path: data.jsonl # or json
@@ -418,10 +407,6 @@ fp16: true
# Use CUDA tf32
tf32: true # require >=ampere
# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
float16: true
# a list of one or more datasets to finetune the model with
datasets:
# hf dataset repo | "json" for local dataset, make sure to fill data_files
@@ -474,9 +459,6 @@ dataset_shard_idx:
# the maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048
sequence_len: 2048
# pad inputs so each step uses constant sized buffers
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
# max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED
@@ -625,14 +607,12 @@ fsdp_config:
# Deepspeed config path
deepspeed:
# Advanced DDP Arguments
ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:
# Set padding for data collator to 'longest'
collator_pad_to_longest:
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
pretraining_dataset:

View File

@@ -35,7 +35,10 @@
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"betas": [
0.9,
0.95
],
"eps": 1e-8,
"weight_decay": "auto"
}

View File

@@ -11,7 +11,7 @@ RUN apt-get update && \
WORKDIR /workspace
RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main"
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main"
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \

View File

@@ -47,3 +47,4 @@ local_rank:
gradient_checkpointing: true
fsdp:
fsdp_config:
collator_pad_to_longest: true

View File

@@ -25,4 +25,3 @@ rouge-score==0.1.2
scipy
scikit-learn==1.2.2
pynvml
art

View File

@@ -4,28 +4,27 @@ import importlib
import logging
import os
import random
import signal
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import fire
import torch
import transformers
import yaml
# add src to the pythonpath so we don't need to pip install this
from art import text2art
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta, train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_model_config, load_tokenizer
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -38,12 +37,15 @@ LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(" axolotl", font=font)
def print_axolotl_text_art():
ascii_art = """
dP dP dP
88 88 88
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
"""
if is_main_process():
print(ascii_art)
@@ -58,45 +60,7 @@ def get_multi_line_input() -> Optional[str]:
return instruction
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
@@ -171,10 +135,6 @@ def choose_config(path: Path):
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
@@ -197,7 +157,12 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
return not any(el in list2 for el in list1)
def load_cfg(config: Path = Path("examples/"), **kwargs):
def train(
config: Path = Path("configs/"),
prepare_ds_only: bool = False,
**kwargs,
):
print_axolotl_text_art()
if Path(config).is_dir():
config = choose_config(config)
@@ -216,72 +181,146 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
else:
cfg[k] = kwargs[k]
model_config = load_model_config(cfg)
# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
validate_config(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)
return cfg
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if cli_args.debug or cfg.debug:
if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
if prepare_ds_only:
LOG.info("Finished preparing dataset. Exiting...")
return
# Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer)
safe_serialization = cfg.save_safetensors is True
if "merge_lora" in kwargs and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if cfg.inference:
LOG.info("calling do_inference function")
prompter: Optional[str] = "AlpacaPrompter"
if "prompter" in kwargs:
if kwargs["prompter"] == "None":
prompter = None
else:
prompter = kwargs["prompter"]
do_inference(cfg, model, tokenizer, prompter=prompter)
return
if "shard" in kwargs:
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.merge_lora:
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if __name__ == "__main__":
fire.Fire(do_cli)
fire.Fire(train)

View File

@@ -1,43 +0,0 @@
"""
shared module for cli specific things
"""
import logging
from dataclasses import dataclass, field
from typing import Optional
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model and (optionally) peft_config...")
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
return model, tokenizer

View File

@@ -1,139 +0,0 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
import os
import signal
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
# add src to the pythonpath so we don't need to pip install this
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.train")
@dataclass
class TrainDatasetMeta:
"""
dataclass to capture the dataset specific options for training
"""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def train(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
dataset_meta: TrainDatasetMeta,
):
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
safe_serialization = cfg.save_safetensors is True
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return model, tokenizer
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return model, tokenizer

View File

@@ -11,7 +11,6 @@ import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from accelerate.state import PartialState
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
@@ -25,9 +24,11 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
barrier,
gather_scalar_from_all_ranks,
get_world_size,
is_main_process,
zero_first,
)
if TYPE_CHECKING:
@@ -35,7 +36,6 @@ if TYPE_CHECKING:
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
dist_state = PartialState()
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -210,7 +210,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
"subject": example["subject"],
}
with dist_state.main_process_first():
with zero_first(is_main_process()):
bench_dataset = bench_dataset.map(tokenize_evals)
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
@@ -258,7 +258,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p)
bench_names[s]["refs"].append(r)
dist_state.wait_for_everyone()
barrier()
local_bench_names = bench_names
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
# Gather results from all GPUs to GPU 0
@@ -275,7 +275,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
else:
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
results = {f"{bench_split}_bench_loss": bench_loss}
results = {"bench_loss": bench_loss}
# Combine results from all GPUs
combined_bench_names: Dict[str, Dict[str, List]] = {}
@@ -287,8 +287,6 @@ def bench_eval_callback_factory(trainer, tokenizer):
combined_bench_names[name]["preds"].extend(data["preds"])
bench_scores = []
bench_refs = []
bench_preds = []
for (
bench_name
) in combined_bench_names: # pylint: disable=consider-using-dict-items
@@ -296,20 +294,15 @@ def bench_eval_callback_factory(trainer, tokenizer):
references=combined_bench_names[bench_name]["refs"],
predictions=combined_bench_names[bench_name]["preds"],
)["accuracy"]
bench_refs.extend(combined_bench_names[bench_name]["refs"])
bench_preds.extend(combined_bench_names[bench_name]["preds"])
if not pd.isna(bench_score):
results[
f"{bench_split}_bench_accuracy_{bench_name}"
f"bench_{bench_split}_accuracy_{bench_name}"
] = bench_score
bench_scores.append(bench_score)
else:
results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0
results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
bench_scores.append(0.0)
results[f"{bench_split}_bench_average_accuracy"] = np.mean(bench_scores)
results[f"{bench_split}_bench_total_accuracy"] = accuracy.compute(
references=bench_refs, predictions=bench_preds
)["accuracy"]
results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
trainer.log(results)
return BenchEvalCallback

View File

@@ -7,7 +7,6 @@ from pathlib import Path
from typing import Tuple, Union
import torch
from accelerate.state import PartialState
from datasets import (
Dataset,
DatasetDict,
@@ -43,6 +42,7 @@ from axolotl.prompters import (
SummarizeTLDRPrompter,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
@@ -50,12 +50,11 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
state = PartialState()
def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset:
with state.main_process_first():
with zero_first(is_main_process()):
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
@@ -70,7 +69,7 @@ def prepare_dataset(cfg, tokenizer):
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
with state.main_process_first():
with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
@@ -135,17 +134,8 @@ def load_tokenized_prepared_datasets(
seed = 42
datasets = []
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
else:
yield dataset
# pylint: disable=invalid-name
for d in for_d_in_datasets(cfg.datasets):
for d in cfg.datasets:
ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False
try:
@@ -508,7 +498,7 @@ def load_prepare_datasets(
to_hash_test.encode(), usedforsecurity=False
).hexdigest()
with state.main_process_first():
with zero_first(is_main_process()):
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,

View File

@@ -1,27 +1,29 @@
"""
utility helpers for distributed checks
"""
import os
from contextlib import contextmanager
import torch
import torch.distributed as dist
from accelerate import DistributedType
from accelerate.state import PartialState
from accelerate.utils import wait_for_everyone
from accelerate import Accelerator
accelerate = None # pylint: disable=invalid-name
state = PartialState()
def load_accelerate():
global accelerate # pylint: disable=global-statement
accelerate = Accelerator()
def is_distributed():
"""
Check if distributed training is initialized.
"""
return state.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_CPU,
DistributedType.DEEPSPEED,
DistributedType.FSDP,
)
global accelerate # pylint: disable=global-statement
if not accelerate:
accelerate = Accelerator()
return dist.is_available() and dist.is_initialized()
def barrier():
@@ -29,19 +31,34 @@ def barrier():
Acts as a barrier to wait for all processes. This ensures that all processes
reach the barrier before proceeding further.
"""
wait_for_everyone()
if is_distributed():
dist.barrier()
def is_main_process() -> bool:
def is_main_process():
"""
Check if the current process is the main process.
If not in distributed mode, always return True.
"""
return state.is_main_process
if not is_distributed():
return True
return dist.get_rank() == 0
def get_world_size() -> int:
return state.num_processes
def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
@contextmanager
def zero_first(is_main):
"""
runs the wrapped context so that rank 0 runs first before other ranks
"""
if not is_main: # other ranks wait first
barrier()
yield
if is_main: # then rank 0 waits after it has run the context
barrier()
def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
@@ -59,7 +76,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
if not state.is_main_process:
if not is_main_process():
dist.gather(value_tensor, dst=0)
else:
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]

View File

@@ -5,13 +5,12 @@ import logging
import math
import os
from pathlib import Path
from typing import Optional, Tuple # noqa: F401
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
@@ -24,17 +23,13 @@ from transformers import ( # noqa: F401
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl")
if TYPE_CHECKING:
from peft import PeftConfig # noqa: F401
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code: bool = False or cfg.trust_remote_code
return AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
from axolotl.utils.dict import DictDefault # noqa: F401
def load_tokenizer(cfg):
@@ -91,10 +86,8 @@ def load_tokenizer(cfg):
def load_model(
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
inference: bool = False,
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
cfg, tokenizer
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
"""
Load a model for a given configuration and tokenizer.
"""
@@ -104,9 +97,14 @@ def load_model(
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = (
"llama" in base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
or cfg.is_llama_derived_model
)
if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and not inference:
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
@@ -148,7 +146,7 @@ def load_model(
if (
cfg.is_llama_derived_model
and (cfg.max_packed_sequence_len or cfg.sample_packing)
and not inference
and not cfg.inference
):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
@@ -371,7 +369,7 @@ def load_model(
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules():
if "norm" in name:
@@ -426,15 +424,15 @@ def load_model(
return model, lora_config
def load_adapter(model, cfg, adapter, inference=False):
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
def load_adapter(model, cfg, adapter):
# type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference)
return load_lora(model, cfg)
if adapter == "llama-adapter":
return load_llama_adapter(model, cfg)
@@ -466,8 +464,12 @@ def load_llama_adapter(model, cfg):
return model, peft_config
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
def find_all_linear_names(bits, model):
cls = (
bnb.nn.Linear4bit
if bits == 4
else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
)
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
@@ -480,15 +482,21 @@ def find_all_linear_names(model):
return list(lora_module_names)
def load_lora(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
def load_lora(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or [])
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
bits = None
if cfg.load_in_4bit:
bits = 4
elif cfg.load_in_8bit:
bits = 8
linear_names = find_all_linear_names(bits, model)
LOG.info(f"found linear modules: {repr(linear_names)}")
lora_target_modules = list(set(lora_target_modules + linear_names))
@@ -508,7 +516,7 @@ def load_lora(model, cfg, inference=False):
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
is_trainable=(not inference),
is_trainable=not cfg.inference,
)
else:
model = get_peft_model(model, lora_config)

View File

@@ -8,13 +8,13 @@ from termcolor import colored
LOG = logging.getLogger("axolotl")
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
def check_dataset_labels(dataset, tokenizer):
# the dataset is already shuffled, so let's just check the first 5 elements
for idx in range(num_examples):
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
for idx in range(5):
check_example_labels(dataset[idx], tokenizer)
def check_example_labels(example, tokenizer, text_only=False):
def check_example_labels(example, tokenizer):
# Get the input_ids, labels, and attention_mask from the dataset
input_ids = example["input_ids"]
labels = example["labels"]
@@ -29,10 +29,8 @@ def check_example_labels(example, tokenizer, text_only=False):
decoded_input_token = tokenizer.decode(input_id)
# Choose the color based on whether the label has the ignore value or not
color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
colored_token = colored(decoded_input_token, color) + (
not text_only
and colored(f"({label_id}, {mask}, {input_id})", "white")
or ""
colored_token = colored(decoded_input_token, color) + colored(
f"({label_id}, {mask}, {input_id})", "white"
)
colored_tokens.append(colored_token)

View File

@@ -361,7 +361,7 @@ def add_position_ids(sample):
def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
return len(sample["input_ids"]) <= sequence_len
@contextmanager
@@ -401,16 +401,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing_eff_est:
total_num_steps = (
# match count to len est in dataloader
@@ -589,15 +579,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
# DDP Config
if cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
if cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
@@ -666,12 +647,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
"padding": True,
}
if cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
cfg.sequence_len / 64
)
if cfg.collator_pad_to_longest:
data_collator_kwargs["padding"] = "longest"
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html