Compare commits

...

21 Commits

Author SHA1 Message Date
Wing Lian
83d904a27d fix the context manager call
2023-09-03 22:49:28 -04:00
Wing Lian
5e4a760ad8 start to swap out for accelerate partial state 2023-09-03 22:41:29 -04:00
Maxime
1991946c5a fix: bad dtype for full finetune (#504)
* fix: bad dtype for full finetune

* Update src/axolotl/utils/models.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update models.py

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-09-01 07:11:45 -07:00
NanoCode012
f51c9c56c6 Fix(doc): Inform Windows users to use WSL/docker (#518) 2023-09-01 00:08:21 -07:00
Wing Lian
7710e81f50 log supervised token count (#448) 2023-08-31 15:45:23 -07:00
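The count logged here is the number of label positions that actually contribute to the loss. A minimal sketch of the idea, assuming a tokenized dataset whose `labels` use -100 for ignored positions (mirroring the utils/trainer.py hunk further down):

```
import numpy as np

def count_supervised_tokens(labels_column):
    # labels_column: iterable of per-example label lists, e.g. train_dataset["labels"]
    # positions labelled -100 are ignored by the loss, so only the rest are "supervised"
    return sum(int(np.sum(np.array(labels) != -100)) for labels in labels_column)
```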
Tom Jobbins
48434bec54 Debug tokenization output: Add ability to output text only (no tokens), and/or specify num samples to see (#511) 2023-08-31 14:26:52 -07:00
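A sketch of the two new debug knobs, using the `check_dataset_labels` signature from the utils/tokenization.py hunk further down; the tokenized dataset and tokenizer are assumed to already be in scope:

```
from axolotl.utils.tokenization import check_dataset_labels

check_dataset_labels(
    train_dataset.select(range(3)),  # assumes a tokenized train_dataset is in scope
    tokenizer,
    num_examples=3,   # exposed on the CLI as debug_num_examples
    text_only=True,   # exposed as debug_text_only: decoded text only, no (label, mask, id) tuples
)
```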
Jan Philipp Harries
396a7a74fc Added advanced DDP args (#515)
* add ddp_config

* add advanced ddp config

* add ddp_config

* add advanced ddp config

---------

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
2023-08-31 10:37:47 -07:00
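The three new config keys are passed straight through to `transformers.TrainingArguments` (see the utils/trainer.py hunk further down). A sketch with illustrative values only:

```
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",             # placeholder
    ddp_timeout=7200,             # cfg: ddp_timeout (seconds)
    ddp_bucket_cap_mb=25,         # cfg: ddp_bucket_cap_mb
    ddp_broadcast_buffers=False,  # cfg: ddp_broadcast_buffers
)
```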
Wing Lian
b21e4a20fe split train from other cli options (#503) 2023-08-30 22:01:47 -07:00
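After the split, the CLI mostly orchestrates three steps; a flow sketch using names from the scripts/finetune.py and src/axolotl/train.py hunks further down (load_cfg and load_datasets live in the script, so this is not a self-contained import):

```
# flow sketch only: parsed_cli_args comes from HfArgumentParser(TrainerCliArgs)
parsed_cfg = load_cfg(config, **kwargs)                                  # YAML config + CLI overrides
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)   # tokenize/prepare datasets
model, tokenizer = train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
```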
Alpay Ariyak
42f9642792 Changed Bench Eval to report metrics correctly by split. Added total accuracy and renamed previously used bench_accuracy to bench_average_accuracy. (#512)
* Added "eval_" prefix

* Added total bench accuracy and renamed the previous one to bench_average_accuracy. Changed naming to use bench_split instead of always using eval_ prefix.
2023-08-30 22:00:50 -07:00
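The resulting metric keys are prefixed with the benchmark split instead of a hard-coded `eval_`; key patterns are taken from the callbacks.py hunk further down, while subject names and values here are illustrative:

```
bench_split = "eval"
results = {
    f"{bench_split}_bench_loss": 1.23,
    f"{bench_split}_bench_accuracy_stem": 0.41,      # one key per benchmark subject
    f"{bench_split}_bench_average_accuracy": 0.39,   # mean of the per-subject accuracies
    f"{bench_split}_bench_total_accuracy": 0.40,     # accuracy over all pooled predictions
}
```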
Wing Lian
c56b450cf5 drop empty tokenized rows too (#509) 2023-08-30 06:55:26 -07:00
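The length filter now also rejects rows that tokenized to nothing; the condition below is taken from `drop_long_seq` in the utils/trainer.py hunk further down:

```
def drop_long_seq(sample, sequence_len=2048):
    # keep only rows that are non-empty and fit within the configured sequence length
    return 0 < len(sample["input_ids"]) <= sequence_len
```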
Aman Gupta Karmani
1e07c162f1 set zero3 optimizer betas to auto so they inherit from HF trainer config (#507) 2023-08-30 08:10:33 -04:00
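The changed optimizer block, shown here as a Python dict for readability (the actual file is the DeepSpeed JSON config in the hunk further down); "auto" lets the HF Trainer fill in the betas from its own arguments:

```
optimizer = {
    "type": "AdamW",
    "params": {
        "lr": "auto",
        "betas": "auto",       # previously [0.9, 0.95]; now inherited from the HF trainer config
        "eps": 1e-8,
        "weight_decay": "auto",
    },
}
```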
Wing Lian
76576323df add eval benchmark callback (#441)
* add mmlu callback

* use hf dataset for mmlu evals

* default to mmlu-zs

* make sure to define all the explicit positional args

* include metrics in callback

* another callback fix for collator max len attribute

* fix mmlu evals

* sample benchmarks, ensure we drop long samples

* fix the data file

* fix elif and add better messaging

* more fixes

* rename mmlu to bench

* more fixes

* dataset handling and aggregate across benchmark

* better handling when no subjects

* benchmark callback has its own dataloader and collator

* fixes

* updated dataset

* more fixes

* missing transformers import

* improve support for customized dataset for bench evals

* gather benchmarks from all ranks

* fix for gather across multiple gpus
2023-08-29 13:24:19 -07:00
Wing Lian
548787daae customizable ascii art (#506) 2023-08-29 10:13:42 -07:00
Wing Lian
5ac3392075 support for datasets with multiple names (#480)
* support for datasets with multiple names

* update docs
2023-08-29 06:18:17 -07:00
Aman Gupta Karmani
e356b297cb remove --force-reinstall from Dockerfile to ensure correct pytorch version (#492) 2023-08-29 06:17:51 -07:00
NanoCode012
48c56470d0 Fix(doc): Clarify no amp to full yaml docs (#496) 2023-08-29 06:17:37 -07:00
Maxime
36b2e1cfee tweak: use default config file when only one file is present (#501) 2023-08-29 06:17:10 -07:00
Wing Lian
125cccb786 Refactor train cfg cli (#499)
* wip to cleanup cfg cli options

* fix launcher

* fix cli args
2023-08-29 05:37:53 -07:00
Aman Karmani
fd55bc87e2 use math.ceil instead of round /cc #498 2023-08-29 01:03:41 +00:00
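Why `math.ceil` rather than `round`: the padded length must never fall below the configured sequence length. The expression is the one used in the utils/trainer.py hunk further down; the sequence length value is illustrative:

```
import math

sequence_len = 4000
pad_to_multiple_of = 64 * math.ceil(sequence_len / 64)  # 4032, always >= sequence_len
# 64 * round(sequence_len / 64) would give 3968 here, shorter than the worst-case sample
```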
Birch-san
8e197f6fb4 pad_to_worst_case_seq_len boolean, for testing memory limits (#498)
* pad_to_worst_case_seq_len boolean, for testing memory limits

* remove collator_pad_to_longest option since it does nothing

see docs: https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorWithPadding.padding

True and "longest" mean the same thing

* rename to `pad_to_sequence_len`, and ensure 64 alignment

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
2023-08-28 18:47:16 -04:00
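What the renamed option does to the data collator, following the utils/trainer.py hunk further down; `cfg` and `tokenizer` are assumed to be the usual axolotl objects:

```
import math
import transformers

data_collator_kwargs = {"padding": True}  # True and "longest" are equivalent, hence collator_pad_to_longest was dropped
if cfg.pad_to_sequence_len:
    # pad every batch to a fixed, 64-aligned length so each step reuses constant-sized buffers
    data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(cfg.sequence_len / 64)

collator = transformers.DataCollatorForSeq2Seq(tokenizer, return_tensors="pt", **data_collator_kwargs)
```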
Aman Karmani
267b7b24e5 simplify linear layer locator 2023-08-28 09:45:16 -04:00
15 changed files with 716 additions and 222 deletions


@@ -163,6 +163,8 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
``` ```
</details> </details>
- Windows: Please use WSL or Docker!
### Dataset ### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use. Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
@@ -328,6 +330,15 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
name: enron_emails name: enron_emails
type: completion # format from earlier type: completion # format from earlier
# huggingface repo with multiple named configurations/subsets
datasets:
- path: bigcode/commitpackft
name:
- ruby
- python
- typescript
type: ... # unimplemented custom format
# local # local
datasets: datasets:
- path: data.jsonl # or json - path: data.jsonl # or json
@@ -407,6 +418,10 @@ fp16: true
# Use CUDA tf32 # Use CUDA tf32
tf32: true # require >=ampere tf32: true # require >=ampere
# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
float16: true
# a list of one or more datasets to finetune the model with # a list of one or more datasets to finetune the model with
datasets: datasets:
# hf dataset repo | "json" for local dataset, make sure to fill data_files # hf dataset repo | "json" for local dataset, make sure to fill data_files
@@ -459,6 +474,9 @@ dataset_shard_idx:
# the maximum length of an input to train with, this should typically be less than 2048 # the maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048 # as most models have a token/context limit of 2048
sequence_len: 2048 sequence_len: 2048
# pad inputs so each step uses constant sized buffers
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
# max sequence length to concatenate training samples together up to # max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED # FutureWarning: This will soon be DEPRECATED
@@ -607,12 +625,14 @@ fsdp_config:
# Deepspeed config path # Deepspeed config path
deepspeed: deepspeed:
# Advanced DDP Arguments
ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:
# Path to torch distx for optim 'adamw_anyprecision' # Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path: torchdistx_path:
# Set padding for data collator to 'longest'
collator_pad_to_longest:
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize # Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
pretraining_dataset: pretraining_dataset:


@@ -35,10 +35,7 @@
"type": "AdamW", "type": "AdamW",
"params": { "params": {
"lr": "auto", "lr": "auto",
"betas": [ "betas": "auto",
0.9,
0.95
],
"eps": 1e-8, "eps": 1e-8,
"weight_decay": "auto" "weight_decay": "auto"
} }


@@ -11,7 +11,7 @@ RUN apt-get update && \
WORKDIR /workspace WORKDIR /workspace
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main" RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main"
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \ RUN cd axolotl && \


@@ -47,4 +47,3 @@ local_rank:
gradient_checkpointing: true gradient_checkpointing: true
fsdp: fsdp:
fsdp_config: fsdp_config:
collator_pad_to_longest: true


@@ -4,6 +4,7 @@ transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
addict addict
evaluate
fire fire
PyYAML>=6.0 PyYAML>=6.0
datasets datasets
@@ -24,3 +25,4 @@ rouge-score==0.1.2
scipy scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml
art


@@ -4,27 +4,28 @@ import importlib
import logging import logging
import os import os
import random import random
import signal
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import fire import fire
import torch import torch
import transformers
import yaml import yaml
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from optimum.bettertransformer import BetterTransformer from art import text2art
from transformers import GenerationConfig, TextStreamer from transformers import GenerationConfig, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta, train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model_config, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -37,15 +38,12 @@ LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def print_axolotl_text_art(): def print_axolotl_text_art(suffix=None):
ascii_art = """ font = "nancyj"
dP dP dP ascii_text = " axolotl"
88 88 88 if suffix:
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88 ascii_text += f" x {suffix}"
88' `88 `8bd8' 88' `88 88 88' `88 88 88 ascii_art = text2art(" axolotl", font=font)
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
"""
if is_main_process(): if is_main_process():
print(ascii_art) print(ascii_art)
@@ -60,7 +58,45 @@ def get_multi_line_input() -> Optional[str]:
return instruction return instruction
def do_inference(cfg, model, tokenizer, prompter: Optional[str]): def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"} default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items(): for token, symbol in default_tokens.items():
@@ -135,6 +171,10 @@ def choose_config(path: Path):
"No YAML config files found in the specified directory. Are you using a .yml extension?" "No YAML config files found in the specified directory. Are you using a .yml extension?"
) )
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
print("Choose a YAML file:") print("Choose a YAML file:")
for idx, file in enumerate(yaml_files): for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}") print(f"{idx + 1}. {file}")
@@ -157,12 +197,7 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
return not any(el in list2 for el in list1) return not any(el in list2 for el in list1)
def train( def load_cfg(config: Path = Path("examples/"), **kwargs):
config: Path = Path("configs/"),
prepare_ds_only: bool = False,
**kwargs,
):
print_axolotl_text_art()
if Path(config).is_dir(): if Path(config).is_dir():
config = choose_config(config) config = choose_config(config)
@@ -181,146 +216,72 @@ def train(
else: else:
cfg[k] = kwargs[k] cfg[k] = kwargs[k]
model_config = load_model_config(cfg)
# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
validate_config(cfg) validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)
return cfg
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
if ( train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if cfg.debug or "debug" in kwargs: if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
check_dataset_labels( check_dataset_labels(
train_dataset.select( train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec [
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
), ),
tokenizer, tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
) )
if prepare_ds_only: return TrainDatasetMeta(
LOG.info("Finished preparing dataset. Exiting...") train_dataset=train_dataset,
return eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
# Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer)
safe_serialization = cfg.save_safetensors is True
if "merge_lora" in kwargs and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if cfg.inference:
LOG.info("calling do_inference function")
prompter: Optional[str] = "AlpacaPrompter"
if "prompter" in kwargs:
if kwargs["prompter"] == "None":
prompter = None
else:
prompter = kwargs["prompter"]
do_inference(cfg, model, tokenizer, prompter=prompter)
return
if "shard" in kwargs:
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
) )
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32": def do_cli(config: Path = Path("examples/"), **kwargs):
LOG.info("Compiling torch model") print_axolotl_text_art()
model = torch.compile(model) parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
# go ahead and presave, so we have the adapter config available to inspect parsed_cli_args, _ = parser.parse_args_into_dataclasses(
if peft_config: return_remaining_strings=True
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") )
peft_config.save_pretrained(cfg.output_dir) if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model elif parsed_cli_args.merge_lora:
if cfg.local_rank == 0: do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
def terminate_handler(_, __, model): shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else: else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint) dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(train) fire.Fire(do_cli)

src/axolotl/common/cli.py (new file, 43 lines)

@@ -0,0 +1,43 @@
"""
shared module for cli specific things
"""
import logging
from dataclasses import dataclass, field
from typing import Optional
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model and (optionally) peft_config...")
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
return model, tokenizer

src/axolotl/train.py (new file, 139 lines)

@@ -0,0 +1,139 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
import os
import signal
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
# add src to the pythonpath so we don't need to pip install this
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.train")
@dataclass
class TrainDatasetMeta:
"""
dataclass to capture the dataset specific options for training
"""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def train(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
dataset_meta: TrainDatasetMeta,
):
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
safe_serialization = cfg.save_safetensors is True
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return model, tokenizer
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return model, tokenizer


@@ -1,9 +1,20 @@
"""Callbacks for Trainer class""" """Callbacks for Trainer class"""
from __future__ import annotations
import logging import logging
import os import os
from typing import TYPE_CHECKING, Dict, List
import evaluate
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from accelerate.state import PartialState
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import ( from transformers import (
TrainerCallback, TrainerCallback,
TrainerControl, TrainerControl,
@@ -13,8 +24,18 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
gather_scalar_from_all_ranks,
get_world_size,
is_main_process,
)
if TYPE_CHECKING:
from axolotl.utils.trainer import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks") LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
dist_state = PartialState()
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -96,3 +117,199 @@ class GPUStatsCallback(
log_gpu_memory_usage(LOG, "while training", self.cfg.device) log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True self.logged = True
return control return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [
tokenizer("A", add_special_tokens=False).input_ids[0],
tokenizer("B", add_special_tokens=False).input_ids[0],
tokenizer("C", add_special_tokens=False).input_ids[0],
tokenizer("D", add_special_tokens=False).input_ids[0],
tokenizer("E", add_special_tokens=False).input_ids[0],
tokenizer("F", add_special_tokens=False).input_ids[0],
tokenizer("G", add_special_tokens=False).input_ids[0],
]
bench_split = "eval"
def transform_bench_subject(example):
# Split on ':' and trim whitespace
parts = example["subject"].split(":")
first_part = (
parts[0].strip().lower().replace("-", "_")
) # Lowercase the first part
second_part = (
parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
) # Replace hyphens with underscores
# Return the transformed values
return {"name": first_part, "subject": second_part}
if trainer.args.bench_dataset == "mmlu-zs":
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
data_files={
"eval": "zero_shot_mmlu_val.json",
"test": "zero_shot_mmlu_test.json",
},
)
# bench_dataset = bench_dataset.remove_columns("subject")
# MMLU Five-shot (Eval/Test only)
elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
data_files={
"eval": "five_shot_mmlu_val.json",
"test": "five_shot_mmlu_test.json",
},
)
# bench_dataset = bench_dataset.remove_columns('subject')
elif "/" in trainer.args.bench_dataset:
bench_ds = trainer.args.bench_dataset
bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
bench_dataset = load_dataset(
bench_ds_name,
data_files={
"eval": bench_ds_data_file,
},
)
bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
else:
raise ValueError(
f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
)
bench_dataset = bench_dataset[trainer.args.bench_split]
if trainer.args.max_bench_samples is not None:
bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
def tokenize_evals(example):
source = f"{tokenizer.bos_token}{example['input']}"
target = f"{example['output']}{tokenizer.eos_token}"
tokenized_source = tokenizer(
source,
max_length=2048,
truncation=True,
add_special_tokens=False,
)
tokenized_target = tokenizer(
target,
max_length=2048,
truncation=True,
add_special_tokens=False,
)
input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
"input_ids"
]
return {
"input_ids": input_ids,
"labels": labels,
"subject": example["subject"],
}
with dist_state.main_process_first():
bench_dataset = bench_dataset.map(tokenize_evals)
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
class BenchEvalCallback(TrainerCallback):
"""
TrainerCallback that runs the MMLU evals
"""
def on_evaluate(
self,
args: AxolotlTrainingArguments,
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
metrics: Dict[str, float], # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
data_loader = trainer.get_bench_dataloader(
bench_dataset.remove_columns(["input", "subject", "output", "name"])
)
trainer.model.eval()
preds, refs = [], []
loss_bench = 0
for batch in tqdm(data_loader, total=len(data_loader)):
(loss, logits, labels) = trainer.prediction_step(
trainer.model,
batch,
prediction_loss_only=False,
)
# There are two tokens, the output, and eos token.
for i, logit in enumerate(logits):
label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
0
][0]
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
preds.append(torch.argmax(logit_abcd).item())
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
refs += [
abcd_idx.index(label) if label in abcd_idx else -1
for label in labels.tolist()
]
loss_bench += loss.item()
# Extract results by subject.
bench_name = bench_dataset["name"]
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p)
bench_names[s]["refs"].append(r)
dist_state.wait_for_everyone()
local_bench_names = bench_names
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
# Gather results from all GPUs to GPU 0
loss_bench_ranks = gather_scalar_from_all_ranks(
lambda: loss_bench, get_world_size()
)
len_data_loader_ranks = gather_scalar_from_all_ranks(
lambda: len(data_loader), get_world_size()
)
if not is_main_process():
dist.gather_object(local_bench_names, dst=0)
else:
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
results = {f"{bench_split}_bench_loss": bench_loss}
# Combine results from all GPUs
combined_bench_names: Dict[str, Dict[str, List]] = {}
for bench_name in gathered_bench_names:
for name, data in bench_name.items():
if name not in combined_bench_names:
combined_bench_names[name] = {"refs": [], "preds": []}
combined_bench_names[name]["refs"].extend(data["refs"])
combined_bench_names[name]["preds"].extend(data["preds"])
bench_scores = []
bench_refs = []
bench_preds = []
for (
bench_name
) in combined_bench_names: # pylint: disable=consider-using-dict-items
bench_score = accuracy.compute(
references=combined_bench_names[bench_name]["refs"],
predictions=combined_bench_names[bench_name]["preds"],
)["accuracy"]
bench_refs.extend(combined_bench_names[bench_name]["refs"])
bench_preds.extend(combined_bench_names[bench_name]["preds"])
if not pd.isna(bench_score):
results[
f"{bench_split}_bench_accuracy_{bench_name}"
] = bench_score
bench_scores.append(bench_score)
else:
results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0
bench_scores.append(0.0)
results[f"{bench_split}_bench_average_accuracy"] = np.mean(bench_scores)
results[f"{bench_split}_bench_total_accuracy"] = accuracy.compute(
references=bench_refs, predictions=bench_preds
)["accuracy"]
trainer.log(results)
return BenchEvalCallback


@@ -7,6 +7,7 @@ from pathlib import Path
from typing import Tuple, Union from typing import Tuple, Union
import torch import torch
from accelerate.state import PartialState
from datasets import ( from datasets import (
Dataset, Dataset,
DatasetDict, DatasetDict,
@@ -42,7 +43,6 @@ from axolotl.prompters import (
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
calculate_total_num_steps, calculate_total_num_steps,
process_datasets_for_packing, process_datasets_for_packing,
@@ -50,11 +50,12 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
state = PartialState()
def prepare_dataset(cfg, tokenizer): def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
with zero_first(is_main_process()): with state.main_process_first():
train_dataset, eval_dataset = load_prepare_datasets( train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )
@@ -69,7 +70,7 @@ def prepare_dataset(cfg, tokenizer):
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
eval_dataset = None eval_dataset = None
with zero_first(is_main_process()): with state.main_process_first():
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset cfg, train_dataset, eval_dataset
) )
@@ -134,8 +135,17 @@ def load_tokenized_prepared_datasets(
seed = 42 seed = 42
datasets = [] datasets = []
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
else:
yield dataset
# pylint: disable=invalid-name # pylint: disable=invalid-name
for d in cfg.datasets: for d in for_d_in_datasets(cfg.datasets):
ds: Union[Dataset, DatasetDict] = None ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False ds_from_hub = False
try: try:
@@ -498,7 +508,7 @@ def load_prepare_datasets(
to_hash_test.encode(), usedforsecurity=False to_hash_test.encode(), usedforsecurity=False
).hexdigest() ).hexdigest()
with zero_first(is_main_process()): with state.main_process_first():
dataset = dataset.train_test_split( dataset = dataset.train_test_split(
test_size=cfg.val_set_size, test_size=cfg.val_set_size,
shuffle=False, shuffle=False,


@@ -1,27 +1,27 @@
""" """
utility helpers for distributed checks utility helpers for distributed checks
""" """
from contextlib import contextmanager import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate import Accelerator from accelerate import DistributedType
from accelerate.state import PartialState
from accelerate.utils import wait_for_everyone
accelerate = None # pylint: disable=invalid-name accelerate = None # pylint: disable=invalid-name
state = PartialState()
def load_accelerate():
global accelerate # pylint: disable=global-statement
accelerate = Accelerator()
def is_distributed(): def is_distributed():
""" """
Check if distributed training is initialized. Check if distributed training is initialized.
""" """
global accelerate # pylint: disable=global-statement return state.distributed_type in (
if not accelerate: DistributedType.MULTI_GPU,
accelerate = Accelerator() DistributedType.MULTI_CPU,
return dist.is_available() and dist.is_initialized() DistributedType.DEEPSPEED,
DistributedType.FSDP,
)
def barrier(): def barrier():
@@ -29,27 +29,48 @@ def barrier():
Acts as a barrier to wait for all processes. This ensures that all processes Acts as a barrier to wait for all processes. This ensures that all processes
reach the barrier before proceeding further. reach the barrier before proceeding further.
""" """
if is_distributed(): wait_for_everyone()
dist.barrier()
def is_main_process(): def is_main_process() -> bool:
""" """
Check if the current process is the main process. Check if the current process is the main process.
If not in distributed mode, always return True. If not in distributed mode, always return True.
""" """
if not is_distributed(): return state.is_main_process
return True
return dist.get_rank() == 0
@contextmanager def get_world_size() -> int:
def zero_first(is_main): return state.num_processes
def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
""" """
runs the wrapped context so that rank 0 runs first before other ranks Run a callable 'fn' on all ranks and gather the results on the specified rank.
Args:
- fn (callable): A function that computes the value. This should not have any side effects.
- rank (int, optional): The rank that gathers the values. Default is 0.
- world_size (int, optional): Total number of processes in the current distributed setup.
Returns:
- A list of computed values from all ranks if on the gathering rank, otherwise None.
""" """
if not is_main: # other ranks wait first value_scalar = fn()
barrier() value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
yield
if is_main: # then rank 0 waits after it has run the context if not state.is_main_process:
barrier() dist.gather(value_tensor, dst=0)
else:
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
# Convert tensors back to their original type (int or float)
gathered_values = []
for tensor in gathered_tensors:
if tensor == tensor.int():
gathered_values.append(int(tensor.item()))
else:
gathered_values.append(float(tensor.item()))
return gathered_values
return None


@@ -5,12 +5,13 @@ import logging
import math import math
import os import os
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401 from typing import Optional, Tuple # noqa: F401
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
@@ -23,13 +24,17 @@ from transformers import ( # noqa: F401
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
if TYPE_CHECKING:
from peft import PeftConfig # noqa: F401
from axolotl.utils.dict import DictDefault # noqa: F401 def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code: bool = False or cfg.trust_remote_code
return AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
def load_tokenizer(cfg): def load_tokenizer(cfg):
@@ -86,8 +91,10 @@ def load_tokenizer(cfg):
def load_model( def load_model(
cfg, tokenizer cfg: DictDefault,
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]] tokenizer: PreTrainedTokenizerBase,
inference: bool = False,
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
""" """
Load a model for a given configuration and tokenizer. Load a model for a given configuration and tokenizer.
""" """
@@ -97,14 +104,9 @@ def load_model(
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = (
"llama" in base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
or cfg.is_llama_derived_model
)
if cfg.is_llama_derived_model and cfg.flash_attention: if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and not cfg.inference: if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import ( from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn, replace_llama_attn_with_flash_attn,
) )
@@ -146,7 +148,7 @@ def load_model(
if ( if (
cfg.is_llama_derived_model cfg.is_llama_derived_model
and (cfg.max_packed_sequence_len or cfg.sample_packing) and (cfg.max_packed_sequence_len or cfg.sample_packing)
and not cfg.inference and not inference
): ):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
@@ -369,7 +371,7 @@ def load_model(
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility. # convert them back to fp16/bf16 for flash-attn compatibility.
if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model): if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype) LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules(): for name, module in model.named_modules():
if "norm" in name: if "norm" in name:
@@ -424,15 +426,15 @@ def load_model(
return model, lora_config return model, lora_config
def load_adapter(model, cfg, adapter): def load_adapter(model, cfg, adapter, inference=False):
# type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
if adapter is None: if adapter is None:
return model, None return model, None
if hasattr(model, "enable_input_require_grads"): if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads() model.enable_input_require_grads()
if adapter in ["lora", "qlora"]: if adapter in ["lora", "qlora"]:
return load_lora(model, cfg) return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter": if adapter == "llama-adapter":
return load_llama_adapter(model, cfg) return load_llama_adapter(model, cfg)
@@ -464,12 +466,8 @@ def load_llama_adapter(model, cfg):
return model, peft_config return model, peft_config
def find_all_linear_names(bits, model): def find_all_linear_names(model):
cls = ( cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
bnb.nn.Linear4bit
if bits == 4
else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
)
lora_module_names = set() lora_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, cls): if isinstance(module, cls):
@@ -482,21 +480,15 @@ def find_all_linear_names(bits, model):
return list(lora_module_names) return list(lora_module_names)
def load_lora(model, cfg): def load_lora(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import LoraConfig, PeftModel, get_peft_model from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or []) lora_target_modules = list(cfg.lora_target_modules or [])
if cfg.lora_target_linear: if cfg.lora_target_linear:
bits = None linear_names = find_all_linear_names(model)
if cfg.load_in_4bit:
bits = 4
elif cfg.load_in_8bit:
bits = 8
linear_names = find_all_linear_names(bits, model)
LOG.info(f"found linear modules: {repr(linear_names)}") LOG.info(f"found linear modules: {repr(linear_names)}")
lora_target_modules = list(set(lora_target_modules + linear_names)) lora_target_modules = list(set(lora_target_modules + linear_names))
@@ -516,7 +508,7 @@ def load_lora(model, cfg):
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.lora_model_dir,
is_trainable=not cfg.inference, is_trainable=(not inference),
) )
else: else:
model = get_peft_model(model, lora_config) model = get_peft_model(model, lora_config)


@@ -8,13 +8,13 @@ from termcolor import colored
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def check_dataset_labels(dataset, tokenizer): def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
# the dataset is already shuffled, so let's just check the first 5 elements # the dataset is already shuffled, so let's just check the first 5 elements
for idx in range(5): for idx in range(num_examples):
check_example_labels(dataset[idx], tokenizer) check_example_labels(dataset[idx], tokenizer, text_only=text_only)
def check_example_labels(example, tokenizer): def check_example_labels(example, tokenizer, text_only=False):
# Get the input_ids, labels, and attention_mask from the dataset # Get the input_ids, labels, and attention_mask from the dataset
input_ids = example["input_ids"] input_ids = example["input_ids"]
labels = example["labels"] labels = example["labels"]
@@ -29,8 +29,10 @@ def check_example_labels(example, tokenizer):
decoded_input_token = tokenizer.decode(input_id) decoded_input_token = tokenizer.decode(input_id)
# Choose the color based on whether the label has the ignore value or not # Choose the color based on whether the label has the ignore value or not
color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
colored_token = colored(decoded_input_token, color) + colored( colored_token = colored(decoded_input_token, color) + (
f"({label_id}, {mask}, {input_id})", "white" not text_only
and colored(f"({label_id}, {mask}, {input_id})", "white")
or ""
) )
colored_tokens.append(colored_token) colored_tokens.append(colored_token)


@@ -12,9 +12,15 @@ from typing import Optional, Union
import numpy as np import numpy as np
import torch.cuda import torch.cuda
import transformers
from datasets import Dataset, set_caching_enabled from datasets import Dataset, set_caching_enabled
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler from torch.utils.data import (
DataLoader,
DistributedSampler,
RandomSampler,
SequentialSampler,
)
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler from transformers.trainer_pt_utils import SequentialDistributedSampler
@@ -23,6 +29,7 @@ from axolotl.utils.callbacks import (
GPUStatsCallback, GPUStatsCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SavePeftModelCallback, SavePeftModelCallback,
bench_eval_callback_factory,
) )
from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -127,6 +134,27 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None, default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
) )
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
class AxolotlTrainer(Trainer): class AxolotlTrainer(Trainer):
@@ -136,6 +164,10 @@ class AxolotlTrainer(Trainer):
args = None # type: AxolotlTrainingArguments args = None # type: AxolotlTrainingArguments
def __init__(self, *args, bench_data_collator=None, **kwargs):
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
): ):
@@ -226,6 +258,31 @@ class AxolotlTrainer(Trainer):
) )
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> Union[DataLoader, MultipackDistributedDataloader]:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False): def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc # use one's weighted cross entropy loss calc
# if self.args.sample_packing: # if self.args.sample_packing:
@@ -304,7 +361,7 @@ def add_position_ids(sample):
def drop_long_seq(sample, sequence_len=2048): def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
@contextmanager @contextmanager
@@ -344,6 +401,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`") LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
cfg.total_num_tokens = total_num_tokens cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing_eff_est: if cfg.sample_packing_eff_est:
total_num_steps = ( total_num_steps = (
# match count to len est in dataloader # match count to len est in dataloader
@@ -517,6 +584,20 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"steps" if cfg.save_steps else "epoch" "steps" if cfg.save_steps else "epoch"
) )
if cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
# DDP Config
if cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
if cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1, max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len, max_seq_length=cfg.sequence_len,
@@ -585,10 +666,12 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
callbacks.append(SaveBetterTransformerModelCallback) callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = { data_collator_kwargs = {
"padding": True, "padding": True, # True/"longest" is the default
} }
if cfg.collator_pad_to_longest: if cfg.pad_to_sequence_len:
data_collator_kwargs["padding"] = "longest" data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
cfg.sequence_len / 64
)
else: else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
@@ -627,8 +710,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
return_tensors="pt", return_tensors="pt",
**data_collator_kwargs, **data_collator_kwargs,
), ),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=callbacks, callbacks=callbacks,
**trainer_kwargs, **trainer_kwargs,
) )
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
return trainer return trainer