refactor scripts/finetune.py into new cli modules (#550)
* refactor scripts/finetune.py into new cli modules * continue to support scripts/finetune.py * update readme with updated cli commands * Update scripts/finetune.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
22
README.md
22
README.md
@@ -76,11 +76,11 @@ pip3 install -e .[flash-attn]
|
|||||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||||
|
|
||||||
# finetune lora
|
# finetune lora
|
||||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
|
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||||
|
|
||||||
# inference
|
# inference
|
||||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
--inference --lora_model_dir="./lora-out"
|
--lora_model_dir="./lora-out"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
@@ -674,14 +674,14 @@ strict:
|
|||||||
|
|
||||||
Run
|
Run
|
||||||
```bash
|
```bash
|
||||||
accelerate launch scripts/finetune.py your_config.yml
|
accelerate launch -m axolotl.cli.train your_config.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Multi-GPU
|
#### Multi-GPU
|
||||||
|
|
||||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
You can optionally pre-tokenize dataset with the following before finetuning:
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
|
CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
|
||||||
```
|
```
|
||||||
|
|
||||||
##### Config
|
##### Config
|
||||||
@@ -720,16 +720,16 @@ Pass the appropriate flag to the train command:
|
|||||||
|
|
||||||
- Pretrained LORA:
|
- Pretrained LORA:
|
||||||
```bash
|
```bash
|
||||||
--inference --lora_model_dir="./lora-output-dir"
|
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
|
||||||
```
|
```
|
||||||
- Full weights finetune:
|
- Full weights finetune:
|
||||||
```bash
|
```bash
|
||||||
--inference --base_model="./completed-model"
|
python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model"
|
||||||
```
|
```
|
||||||
- Full weights finetune w/ a prompt from a text file:
|
- Full weights finetune w/ a prompt from a text file:
|
||||||
```bash
|
```bash
|
||||||
cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
|
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
|
||||||
--base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
|
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
||||||
```
|
```
|
||||||
|
|
||||||
### Merge LORA to base
|
### Merge LORA to base
|
||||||
@@ -737,13 +737,13 @@ Pass the appropriate flag to the train command:
|
|||||||
Add below flag to train command above
|
Add below flag to train command above
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||||
```
|
```
|
||||||
|
|
||||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
|
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
|
||||||
```
|
```
|
||||||
|
|
||||||
## Common Errors 🧰
|
## Common Errors 🧰
|
||||||
|
|||||||
@@ -1,269 +1,34 @@
|
|||||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||||
|
|
||||||
import importlib
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
|
||||||
import transformers
|
import transformers
|
||||||
import yaml
|
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
from axolotl.cli import (
|
||||||
from accelerate.commands.config import config_args
|
check_accelerate_default_config,
|
||||||
from art import text2art
|
do_inference,
|
||||||
from transformers import GenerationConfig, TextStreamer
|
do_merge_lora,
|
||||||
|
load_cfg,
|
||||||
|
load_datasets,
|
||||||
|
print_axolotl_text_art,
|
||||||
|
)
|
||||||
|
from axolotl.cli.shard import shard
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
LOG = logging.getLogger("axolotl.scripts.finetune")
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.train import TrainDatasetMeta, train
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
|
||||||
from axolotl.utils.data import prepare_dataset
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
|
||||||
from axolotl.utils.models import load_tokenizer
|
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
||||||
src_dir = os.path.join(project_root, "src")
|
|
||||||
sys.path.insert(0, src_dir)
|
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = logging.getLogger("axolotl.scripts")
|
|
||||||
|
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
|
||||||
|
|
||||||
|
|
||||||
def print_axolotl_text_art(suffix=None):
|
|
||||||
font = "nancyj"
|
|
||||||
ascii_text = " axolotl"
|
|
||||||
if suffix:
|
|
||||||
ascii_text += f" x {suffix}"
|
|
||||||
ascii_art = text2art(" axolotl", font=font)
|
|
||||||
|
|
||||||
if is_main_process():
|
|
||||||
print(ascii_art)
|
|
||||||
|
|
||||||
|
|
||||||
def get_multi_line_input() -> Optional[str]:
|
|
||||||
print("Give me an instruction (Ctrl + D to finish): ")
|
|
||||||
instruction = ""
|
|
||||||
for line in sys.stdin:
|
|
||||||
instruction += line # pylint: disable=consider-using-join
|
|
||||||
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
|
||||||
return instruction
|
|
||||||
|
|
||||||
|
|
||||||
def do_merge_lora(
|
|
||||||
*,
|
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs,
|
|
||||||
):
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
|
||||||
|
|
||||||
LOG.info("running merge of LoRA with base model")
|
|
||||||
model = model.merge_and_unload()
|
|
||||||
model.to(dtype=torch.float16)
|
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
|
||||||
LOG.info("saving merged model")
|
|
||||||
model.save_pretrained(
|
|
||||||
str(Path(cfg.output_dir) / "merged"),
|
|
||||||
safe_serialization=safe_serialization,
|
|
||||||
)
|
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
|
||||||
|
|
||||||
|
|
||||||
def shard(
|
|
||||||
*,
|
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs,
|
|
||||||
):
|
|
||||||
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
|
||||||
LOG.debug("Re-saving model w/ sharding")
|
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
|
||||||
|
|
||||||
|
|
||||||
def do_inference(
|
|
||||||
*,
|
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs,
|
|
||||||
):
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
|
||||||
prompter = cli_args.prompter
|
|
||||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
|
||||||
|
|
||||||
for token, symbol in default_tokens.items():
|
|
||||||
# If the token isn't already specified in the config, add it
|
|
||||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
|
||||||
tokenizer.add_special_tokens({token: symbol})
|
|
||||||
|
|
||||||
prompter_module = None
|
|
||||||
if prompter:
|
|
||||||
prompter_module = getattr(
|
|
||||||
importlib.import_module("axolotl.prompters"), prompter
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.landmark_attention:
|
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
|
||||||
|
|
||||||
set_model_mem_id(model, tokenizer)
|
|
||||||
model.set_mem_cache_args(
|
|
||||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
|
||||||
)
|
|
||||||
|
|
||||||
model = model.to(cfg.device)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
print("=" * 80)
|
|
||||||
# support for multiline inputs
|
|
||||||
instruction = get_multi_line_input()
|
|
||||||
if not instruction:
|
|
||||||
return
|
|
||||||
if prompter_module:
|
|
||||||
prompt: str = next(
|
|
||||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prompt = instruction.strip()
|
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
|
||||||
|
|
||||||
print("=" * 40)
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
generation_config = GenerationConfig(
|
|
||||||
repetition_penalty=1.1,
|
|
||||||
max_new_tokens=1024,
|
|
||||||
temperature=0.9,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=40,
|
|
||||||
bos_token_id=tokenizer.bos_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
do_sample=True,
|
|
||||||
use_cache=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
output_attentions=False,
|
|
||||||
output_hidden_states=False,
|
|
||||||
output_scores=False,
|
|
||||||
)
|
|
||||||
streamer = TextStreamer(tokenizer)
|
|
||||||
generated = model.generate(
|
|
||||||
inputs=batch["input_ids"].to(cfg.device),
|
|
||||||
generation_config=generation_config,
|
|
||||||
streamer=streamer,
|
|
||||||
)
|
|
||||||
print("=" * 40)
|
|
||||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
|
||||||
|
|
||||||
|
|
||||||
def choose_config(path: Path):
|
|
||||||
yaml_files = list(path.glob("*.yml"))
|
|
||||||
|
|
||||||
if not yaml_files:
|
|
||||||
raise ValueError(
|
|
||||||
"No YAML config files found in the specified directory. Are you using a .yml extension?"
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(yaml_files) == 1:
|
|
||||||
print(f"Using default YAML file '{yaml_files[0]}'")
|
|
||||||
return yaml_files[0]
|
|
||||||
|
|
||||||
print("Choose a YAML file:")
|
|
||||||
for idx, file in enumerate(yaml_files):
|
|
||||||
print(f"{idx + 1}. {file}")
|
|
||||||
|
|
||||||
chosen_file = None
|
|
||||||
while chosen_file is None:
|
|
||||||
try:
|
|
||||||
choice = int(input("Enter the number of your choice: "))
|
|
||||||
if 1 <= choice <= len(yaml_files):
|
|
||||||
chosen_file = yaml_files[choice - 1]
|
|
||||||
else:
|
|
||||||
print("Invalid choice. Please choose a number from the list.")
|
|
||||||
except ValueError:
|
|
||||||
print("Invalid input. Please enter a number.")
|
|
||||||
|
|
||||||
return chosen_file
|
|
||||||
|
|
||||||
|
|
||||||
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
|
|
||||||
return not any(el in list2 for el in list1)
|
|
||||||
|
|
||||||
|
|
||||||
def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|
||||||
if Path(config).is_dir():
|
|
||||||
config = choose_config(config)
|
|
||||||
|
|
||||||
# load the config from the yaml file
|
|
||||||
with open(config, encoding="utf-8") as file:
|
|
||||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
|
||||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
|
||||||
# then overwrite the value
|
|
||||||
cfg_keys = cfg.keys()
|
|
||||||
for k, _ in kwargs.items():
|
|
||||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
|
||||||
if k in cfg_keys or not cfg.strict:
|
|
||||||
# handle booleans
|
|
||||||
if isinstance(cfg[k], bool):
|
|
||||||
cfg[k] = bool(kwargs[k])
|
|
||||||
else:
|
|
||||||
cfg[k] = kwargs[k]
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
normalize_config(cfg)
|
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
def load_datasets(
|
|
||||||
*,
|
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs,
|
|
||||||
) -> TrainDatasetMeta:
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
|
||||||
|
|
||||||
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
|
||||||
|
|
||||||
if cli_args.debug or cfg.debug:
|
|
||||||
LOG.info("check_dataset_labels...")
|
|
||||||
check_dataset_labels(
|
|
||||||
train_dataset.select(
|
|
||||||
[
|
|
||||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
|
||||||
for _ in range(cli_args.debug_num_examples)
|
|
||||||
]
|
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
num_examples=cli_args.debug_num_examples,
|
|
||||||
text_only=cli_args.debug_text_only,
|
|
||||||
)
|
|
||||||
|
|
||||||
return TrainDatasetMeta(
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
eval_dataset=eval_dataset,
|
|
||||||
total_num_steps=total_num_steps,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def check_accelerate_default_config():
|
|
||||||
if Path(config_args.default_yaml_config_file).exists():
|
|
||||||
LOG.warning(
|
|
||||||
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
|
LOG.warning(
|
||||||
|
str(
|
||||||
|
PendingDeprecationWarning(
|
||||||
|
"scripts/finetune.py will be replaced with calling axolotl.cli.train"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
|
|||||||
249
src/axolotl/cli/__init__.py
Normal file
249
src/axolotl/cli/__init__.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
|
from accelerate.commands.config import config_args
|
||||||
|
from art import text2art
|
||||||
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
|
from axolotl.train import TrainDatasetMeta
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
|
from axolotl.utils.data import prepare_dataset
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
from axolotl.utils.models import load_tokenizer
|
||||||
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
src_dir = os.path.join(project_root, "src")
|
||||||
|
sys.path.insert(0, src_dir)
|
||||||
|
|
||||||
|
configure_logging()
|
||||||
|
LOG = logging.getLogger("axolotl.scripts")
|
||||||
|
|
||||||
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
def print_axolotl_text_art(suffix=None):
|
||||||
|
font = "nancyj"
|
||||||
|
ascii_text = " axolotl"
|
||||||
|
if suffix:
|
||||||
|
ascii_text += f" x {suffix}"
|
||||||
|
ascii_art = text2art(" axolotl", font=font)
|
||||||
|
|
||||||
|
if is_main_process():
|
||||||
|
print(ascii_art)
|
||||||
|
|
||||||
|
|
||||||
|
def get_multi_line_input() -> Optional[str]:
|
||||||
|
print("Give me an instruction (Ctrl + D to finish): ")
|
||||||
|
instruction = ""
|
||||||
|
for line in sys.stdin:
|
||||||
|
instruction += line # pylint: disable=consider-using-join
|
||||||
|
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
||||||
|
return instruction
|
||||||
|
|
||||||
|
|
||||||
|
def do_merge_lora(
|
||||||
|
*,
|
||||||
|
cfg: DictDefault,
|
||||||
|
cli_args: TrainerCliArgs,
|
||||||
|
):
|
||||||
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
|
LOG.info("running merge of LoRA with base model")
|
||||||
|
model = model.merge_and_unload()
|
||||||
|
model.to(dtype=torch.float16)
|
||||||
|
|
||||||
|
if cfg.local_rank == 0:
|
||||||
|
LOG.info("saving merged model")
|
||||||
|
model.save_pretrained(
|
||||||
|
str(Path(cfg.output_dir) / "merged"),
|
||||||
|
safe_serialization=safe_serialization,
|
||||||
|
)
|
||||||
|
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
|
|
||||||
|
|
||||||
|
def do_inference(
|
||||||
|
*,
|
||||||
|
cfg: DictDefault,
|
||||||
|
cli_args: TrainerCliArgs,
|
||||||
|
):
|
||||||
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
|
prompter = cli_args.prompter
|
||||||
|
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||||
|
|
||||||
|
for token, symbol in default_tokens.items():
|
||||||
|
# If the token isn't already specified in the config, add it
|
||||||
|
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||||
|
tokenizer.add_special_tokens({token: symbol})
|
||||||
|
|
||||||
|
prompter_module = None
|
||||||
|
if prompter:
|
||||||
|
prompter_module = getattr(
|
||||||
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.landmark_attention:
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||||
|
|
||||||
|
set_model_mem_id(model, tokenizer)
|
||||||
|
model.set_mem_cache_args(
|
||||||
|
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||||
|
)
|
||||||
|
|
||||||
|
model = model.to(cfg.device)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
print("=" * 80)
|
||||||
|
# support for multiline inputs
|
||||||
|
instruction = get_multi_line_input()
|
||||||
|
if not instruction:
|
||||||
|
return
|
||||||
|
if prompter_module:
|
||||||
|
prompt: str = next(
|
||||||
|
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = instruction.strip()
|
||||||
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||||
|
|
||||||
|
print("=" * 40)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
repetition_penalty=1.1,
|
||||||
|
max_new_tokens=1024,
|
||||||
|
temperature=0.9,
|
||||||
|
top_p=0.95,
|
||||||
|
top_k=40,
|
||||||
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
do_sample=True,
|
||||||
|
use_cache=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_attentions=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
output_scores=False,
|
||||||
|
)
|
||||||
|
streamer = TextStreamer(tokenizer)
|
||||||
|
generated = model.generate(
|
||||||
|
inputs=batch["input_ids"].to(cfg.device),
|
||||||
|
generation_config=generation_config,
|
||||||
|
streamer=streamer,
|
||||||
|
)
|
||||||
|
print("=" * 40)
|
||||||
|
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def choose_config(path: Path):
|
||||||
|
yaml_files = list(path.glob("*.yml"))
|
||||||
|
|
||||||
|
if not yaml_files:
|
||||||
|
raise ValueError(
|
||||||
|
"No YAML config files found in the specified directory. Are you using a .yml extension?"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(yaml_files) == 1:
|
||||||
|
print(f"Using default YAML file '{yaml_files[0]}'")
|
||||||
|
return yaml_files[0]
|
||||||
|
|
||||||
|
print("Choose a YAML file:")
|
||||||
|
for idx, file in enumerate(yaml_files):
|
||||||
|
print(f"{idx + 1}. {file}")
|
||||||
|
|
||||||
|
chosen_file = None
|
||||||
|
while chosen_file is None:
|
||||||
|
try:
|
||||||
|
choice = int(input("Enter the number of your choice: "))
|
||||||
|
if 1 <= choice <= len(yaml_files):
|
||||||
|
chosen_file = yaml_files[choice - 1]
|
||||||
|
else:
|
||||||
|
print("Invalid choice. Please choose a number from the list.")
|
||||||
|
except ValueError:
|
||||||
|
print("Invalid input. Please enter a number.")
|
||||||
|
|
||||||
|
return chosen_file
|
||||||
|
|
||||||
|
|
||||||
|
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
|
||||||
|
return not any(el in list2 for el in list1)
|
||||||
|
|
||||||
|
|
||||||
|
def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||||
|
if Path(config).is_dir():
|
||||||
|
config = choose_config(config)
|
||||||
|
|
||||||
|
# load the config from the yaml file
|
||||||
|
with open(config, encoding="utf-8") as file:
|
||||||
|
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||||
|
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||||
|
# then overwrite the value
|
||||||
|
cfg_keys = cfg.keys()
|
||||||
|
for k, _ in kwargs.items():
|
||||||
|
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||||
|
if k in cfg_keys or not cfg.strict:
|
||||||
|
# handle booleans
|
||||||
|
if isinstance(cfg[k], bool):
|
||||||
|
cfg[k] = bool(kwargs[k])
|
||||||
|
else:
|
||||||
|
cfg[k] = kwargs[k]
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
normalize_config(cfg)
|
||||||
|
|
||||||
|
setup_wandb_env_vars(cfg)
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def load_datasets(
|
||||||
|
*,
|
||||||
|
cfg: DictDefault,
|
||||||
|
cli_args: TrainerCliArgs,
|
||||||
|
) -> TrainDatasetMeta:
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
|
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
||||||
|
|
||||||
|
if cli_args.debug or cfg.debug:
|
||||||
|
LOG.info("check_dataset_labels...")
|
||||||
|
check_dataset_labels(
|
||||||
|
train_dataset.select(
|
||||||
|
[
|
||||||
|
random.randrange(0, len(train_dataset) - 1) # nosec
|
||||||
|
for _ in range(cli_args.debug_num_examples)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
num_examples=cli_args.debug_num_examples,
|
||||||
|
text_only=cli_args.debug_text_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TrainDatasetMeta(
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
total_num_steps=total_num_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_accelerate_default_config():
|
||||||
|
if Path(config_args.default_yaml_config_file).exists():
|
||||||
|
LOG.warning(
|
||||||
|
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
||||||
|
)
|
||||||
26
src/axolotl/cli/inference.py
Normal file
26
src/axolotl/cli/inference.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
CLI to run inference on a trained model
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
print_axolotl_text_art()
|
||||||
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
|
return_remaining_strings=True
|
||||||
|
)
|
||||||
|
parsed_cli_args.inference = True
|
||||||
|
|
||||||
|
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|
||||||
|
fire.Fire(do_cli)
|
||||||
26
src/axolotl/cli/merge_lora.py
Normal file
26
src/axolotl/cli/merge_lora.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
CLI to run merge a trained LoRA into a base model
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
print_axolotl_text_art()
|
||||||
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
|
return_remaining_strings=True
|
||||||
|
)
|
||||||
|
parsed_cli_args.merge_lora = True
|
||||||
|
|
||||||
|
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|
||||||
|
fire.Fire(do_cli)
|
||||||
41
src/axolotl/cli/shard.py
Normal file
41
src/axolotl/cli/shard.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""
|
||||||
|
CLI to shard a trained model into 10GiB chunks
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||||
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.scripts")
|
||||||
|
|
||||||
|
|
||||||
|
def shard(
|
||||||
|
*,
|
||||||
|
cfg: DictDefault,
|
||||||
|
cli_args: TrainerCliArgs,
|
||||||
|
):
|
||||||
|
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
LOG.debug("Re-saving model w/ sharding")
|
||||||
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
print_axolotl_text_art()
|
||||||
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
|
return_remaining_strings=True
|
||||||
|
)
|
||||||
|
parsed_cli_args.shard = True
|
||||||
|
|
||||||
|
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|
||||||
|
fire.Fire(do_cli)
|
||||||
35
src/axolotl/cli/train.py
Normal file
35
src/axolotl/cli/train.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
CLI to run training on a model
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
from axolotl.cli import (
|
||||||
|
check_accelerate_default_config,
|
||||||
|
load_cfg,
|
||||||
|
load_datasets,
|
||||||
|
print_axolotl_text_art,
|
||||||
|
)
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
print_axolotl_text_art()
|
||||||
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
check_accelerate_default_config()
|
||||||
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
|
return_remaining_strings=True
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
if parsed_cli_args.prepare_ds_only:
|
||||||
|
return
|
||||||
|
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
|
fire.Fire(do_cli)
|
||||||
Reference in New Issue
Block a user