Compare commits
21 Commits
feature/re
...
embeddings
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
31079cd5fd | ||
|
|
41ecb451c2 | ||
|
|
3c2ad00d07 | ||
|
|
5d48a10548 | ||
|
|
73a0b6ead5 | ||
|
|
63fdb5a7fb | ||
|
|
fdffef5940 | ||
|
|
919246fbc1 | ||
|
|
ffac902c1b | ||
|
|
15f6e57eaa | ||
|
|
729c299256 | ||
|
|
86a91e260b | ||
|
|
094fc2c6e6 | ||
|
|
2dafa730ef | ||
|
|
343ac84e5a | ||
|
|
0c967279ce | ||
|
|
efb3b2c95e | ||
|
|
7b55fe6419 | ||
|
|
e029ab34ea | ||
|
|
8cec513447 | ||
|
|
a13e45d548 |
13
.github/FUNDING.yml
vendored
Normal file
13
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# These are supported funding model platforms
|
||||||
|
|
||||||
|
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||||
|
patreon: # Replace with a single Patreon username
|
||||||
|
open_collective: # Replace with a single Open Collective username
|
||||||
|
ko_fi: # Replace with a single Ko-fi username
|
||||||
|
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||||
|
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||||
|
liberapay: # Replace with a single Liberapay username
|
||||||
|
issuehunt: # Replace with a single IssueHunt username
|
||||||
|
otechie: # Replace with a single Otechie username
|
||||||
|
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||||
|
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||||
18
README.md
18
README.md
@@ -136,7 +136,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"instruction": "...", "input": "...", "output": "..."}
|
{"instruction": "...", "input": "...", "output": "..."}
|
||||||
```
|
```
|
||||||
- `sharegpt:chat`: conversations
|
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
@@ -225,6 +225,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"conversations": [{"role": "...", "value": "..."}]}
|
{"conversations": [{"role": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
|
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
||||||
|
```json
|
||||||
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
|
```
|
||||||
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
|
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
||||||
@@ -322,9 +326,9 @@ tokenizer_type: AutoTokenizer
|
|||||||
trust_remote_code:
|
trust_remote_code:
|
||||||
# use_fast option for tokenizer loading from_pretrained, default to True
|
# use_fast option for tokenizer loading from_pretrained, default to True
|
||||||
tokenizer_use_fast:
|
tokenizer_use_fast:
|
||||||
# resize the model embeddings when new tokens are added to multiples of 32
|
# resize the model embeddings when new tokens are added to multiples of N
|
||||||
# this is reported to improve training speed on some models
|
# multiples of 32 are reported to improve training speed on some models
|
||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_multiple:
|
||||||
|
|
||||||
# whether you are training a 4-bit GPTQ quantized model
|
# whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq: true
|
||||||
@@ -360,6 +364,9 @@ dataset_prepared_path: data/last_run_prepared
|
|||||||
push_dataset_to_hub: # repo path
|
push_dataset_to_hub: # repo path
|
||||||
# push checkpoints to hub
|
# push checkpoints to hub
|
||||||
hub_model_id: # repo path to push finetuned model
|
hub_model_id: # repo path to push finetuned model
|
||||||
|
# how to push checkpoints to hub
|
||||||
|
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
||||||
|
hub_strategy:
|
||||||
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||||
# required to be true when used in combination with `push_dataset_to_hub`
|
# required to be true when used in combination with `push_dataset_to_hub`
|
||||||
hf_use_auth_token: # boolean
|
hf_use_auth_token: # boolean
|
||||||
@@ -428,7 +435,8 @@ learning_rate: 0.00003
|
|||||||
logging_steps:
|
logging_steps:
|
||||||
save_steps:
|
save_steps:
|
||||||
eval_steps:
|
eval_steps:
|
||||||
save_total_limit:
|
save_total_limit: # checkpoints saved at a time
|
||||||
|
max_steps:
|
||||||
|
|
||||||
# save model as safetensors (require safetensors package)
|
# save model as safetensors (require safetensors package)
|
||||||
save_safetensors:
|
save_safetensors:
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
|||||||
|
|
||||||
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
||||||
cd flash-attention && \
|
cd flash-attention && \
|
||||||
git checkout v2.0.1 && \
|
git checkout v2.0.4 && \
|
||||||
python3 setup.py bdist_wheel && \
|
python3 setup.py bdist_wheel && \
|
||||||
cd csrc/fused_dense_lib && \
|
cd csrc/fused_dense_lib && \
|
||||||
python3 setup.py bdist_wheel && \
|
python3 setup.py bdist_wheel && \
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
max_packed_sequence_len: 4096
|
sample_packing: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
@@ -49,8 +49,8 @@ early_stopping_patience:
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
xformers_attention:
|
||||||
flash_attention:
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
eval_steps: 20
|
||||||
@@ -64,4 +64,3 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
pad_token: "<pad>"
|
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ adapter: qlora
|
|||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
max_packed_sequence_len: 4096
|
sample_packing: true
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
@@ -50,8 +51,8 @@ early_stopping_patience:
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
xformers_attention:
|
||||||
flash_attention:
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
eval_steps: 20
|
||||||
@@ -65,4 +66,3 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
pad_token: "<pad>"
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from optimum.bettertransformer import BetterTransformer
|
|||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import barrier, is_main_process
|
from axolotl.utils.distributed import barrier, is_main_process
|
||||||
@@ -29,7 +29,6 @@ from axolotl.utils.trainer import (
|
|||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
setup_trainer,
|
setup_trainer,
|
||||||
)
|
)
|
||||||
from axolotl.utils.validation import validate_config
|
|
||||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -44,27 +43,6 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
|||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
|
||||||
|
|
||||||
def choose_device(cfg):
|
|
||||||
def get_device():
|
|
||||||
try:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return f"cuda:{cfg.local_rank}"
|
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
return "mps"
|
|
||||||
|
|
||||||
raise SystemError("No CUDA/mps device found")
|
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
|
||||||
return "cpu"
|
|
||||||
|
|
||||||
cfg.device = get_device()
|
|
||||||
if cfg.device_map != "auto":
|
|
||||||
if cfg.device.startswith("cuda"):
|
|
||||||
cfg.device_map = {"": cfg.local_rank}
|
|
||||||
else:
|
|
||||||
cfg.device_map = {"": cfg.device}
|
|
||||||
|
|
||||||
|
|
||||||
def get_multi_line_input() -> Optional[str]:
|
def get_multi_line_input() -> Optional[str]:
|
||||||
print("Give me an instruction (Ctrl + D to finish): ")
|
print("Give me an instruction (Ctrl + D to finish): ")
|
||||||
instruction = ""
|
instruction = ""
|
||||||
@@ -194,36 +172,13 @@ def train(
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
# setup some derived config / hyperparams
|
normalize_config(cfg)
|
||||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
|
||||||
cfg.batch_size // cfg.micro_batch_size
|
|
||||||
)
|
|
||||||
cfg.batch_size = (
|
|
||||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
|
||||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
||||||
choose_device(cfg)
|
|
||||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
|
||||||
if cfg.ddp:
|
|
||||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
|
||||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
if cfg.device == "mps":
|
|
||||||
cfg.load_in_8bit = False
|
|
||||||
cfg.tf32 = False
|
|
||||||
if cfg.bf16:
|
|
||||||
cfg.fp16 = True
|
|
||||||
cfg.bf16 = False
|
|
||||||
|
|
||||||
if cfg.tf32:
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||||
LOG.info(f"loading tokenizer... {tokenizer_config}")
|
tokenizer = load_tokenizer(cfg)
|
||||||
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
||||||
@@ -254,7 +209,13 @@ def train(
|
|||||||
cfg, train_dataset, eval_dataset
|
cfg, train_dataset, eval_dataset
|
||||||
)
|
)
|
||||||
barrier()
|
barrier()
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
if cfg.max_steps:
|
||||||
|
total_num_steps = min(
|
||||||
|
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
||||||
|
)
|
||||||
|
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||||
|
else:
|
||||||
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||||
|
|
||||||
if cfg.debug or "debug" in kwargs:
|
if cfg.debug or "debug" in kwargs:
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
@@ -269,8 +230,6 @@ def train(
|
|||||||
LOG.info("Finished preparing dataset. Exiting...")
|
LOG.info("Finished preparing dataset. Exiting...")
|
||||||
return
|
return
|
||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
LOG.info("loading model and (optionally) peft_config...")
|
LOG.info("loading model and (optionally) peft_config...")
|
||||||
model, peft_config = load_model(cfg, tokenizer)
|
model, peft_config = load_model(cfg, tokenizer)
|
||||||
@@ -354,6 +313,7 @@ def train(
|
|||||||
|
|
||||||
if not Path(cfg.output_dir).is_dir():
|
if not Path(cfg.output_dir).is_dir():
|
||||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||||
|
tokenizer.save_pretrained(cfg.output_dir)
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with torch.backends.cuda.sdp_kernel(
|
||||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||||
@@ -371,14 +331,8 @@ def train(
|
|||||||
elif cfg.local_rank == 0:
|
elif cfg.local_rank == 0:
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
|
||||||
if cfg.adapter == "lora" and cfg.relora_steps:
|
|
||||||
model = model.merge_and_unload()
|
|
||||||
|
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
fire.Fire(train)
|
fire.Fire(train)
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ def forward(
|
|||||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
elif position_ids.shape[0] == 1:
|
elif attention_mask.shape[0] == 1:
|
||||||
# special handling using sample packing
|
# special handling using sample packing
|
||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
|||||||
@@ -1,302 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
import glob
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os.path
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Sequence
|
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
import peft
|
|
||||||
import safetensors.torch as st
|
|
||||||
import torch
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
from torch.optim.optimizer import Optimizer
|
|
||||||
from transformers import (
|
|
||||||
TrainerCallback,
|
|
||||||
TrainerControl,
|
|
||||||
TrainerState,
|
|
||||||
TrainingArguments,
|
|
||||||
)
|
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.relora")
|
|
||||||
|
|
||||||
|
|
||||||
def reset_optimizer(optimizer: torch.optim.Optimizer):
|
|
||||||
for group in optimizer.param_groups:
|
|
||||||
for param in group["params"]:
|
|
||||||
param_state = optimizer.state[param]
|
|
||||||
for key in param_state:
|
|
||||||
if "qmap" in key:
|
|
||||||
continue
|
|
||||||
elif key == "step" and isinstance(param_state[key], int):
|
|
||||||
param_state[key] = 0
|
|
||||||
else:
|
|
||||||
param_state[key] = torch.zeros_like(param_state[key])
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRACallback(TrainerCallback):
|
|
||||||
def __init__(self, cfg: DictDefault):
|
|
||||||
self.relora_steps = cfg.relora_steps
|
|
||||||
self.cpu_offload = cfg.relora_cpu_offload
|
|
||||||
self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
|
|
||||||
self.last_full_model = cfg.base_model
|
|
||||||
|
|
||||||
assert os.path.exists(
|
|
||||||
self.last_full_model
|
|
||||||
), "for ReLORA base_model must be a local path"
|
|
||||||
|
|
||||||
self.num_lora_restarts = 0
|
|
||||||
self.need_full_save = False
|
|
||||||
|
|
||||||
def on_step_begin(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments,
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
model: peft.LoraModel,
|
|
||||||
optimizer: torch.optim.Optimizer,
|
|
||||||
**_kwargs,
|
|
||||||
):
|
|
||||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
|
||||||
checkpoint_folder = os.path.join(
|
|
||||||
args.output_dir,
|
|
||||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
merge_and_save(
|
|
||||||
model,
|
|
||||||
self.last_full_model,
|
|
||||||
checkpoint_folder,
|
|
||||||
reinit=True,
|
|
||||||
quantized=self.quantised,
|
|
||||||
)
|
|
||||||
reset_optimizer(optimizer)
|
|
||||||
|
|
||||||
if self.quantised:
|
|
||||||
self.last_full_model = checkpoint_folder
|
|
||||||
self.num_lora_restarts += 1
|
|
||||||
|
|
||||||
return control
|
|
||||||
|
|
||||||
def on_save(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments,
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
model: peft.LoraModel,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
checkpoint_folder = os.path.join(
|
|
||||||
args.output_dir,
|
|
||||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
state.global_step >= self.relora_steps
|
|
||||||
and state.global_step % self.relora_steps != 0
|
|
||||||
):
|
|
||||||
if self.quantised and self.last_full_model != checkpoint_folder:
|
|
||||||
# ensure the latest full parameter save is in the latest checkpoint
|
|
||||||
# folder, so that automatic pruning of checkpoints does not remove it
|
|
||||||
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
|
|
||||||
chunks = glob.glob(
|
|
||||||
f"{self.last_full_model}/model*.safetensors"
|
|
||||||
) + glob.glob(f"{self.last_full_model}/model*.index.json")
|
|
||||||
for path in chunks:
|
|
||||||
shutil.move(path, checkpoint_folder)
|
|
||||||
self.last_full_model = checkpoint_folder
|
|
||||||
else:
|
|
||||||
model.model.save_pretrained(checkpoint_folder, save_safetensors=True)
|
|
||||||
|
|
||||||
return control
|
|
||||||
|
|
||||||
def on_log(
|
|
||||||
self,
|
|
||||||
_args: TrainingArguments,
|
|
||||||
_state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
logs: Dict[str, float],
|
|
||||||
**_kwargs,
|
|
||||||
):
|
|
||||||
logs["num_lora_restarts"] = self.num_lora_restarts
|
|
||||||
return control
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAScheduler(LRScheduler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
optimizer: Optimizer,
|
|
||||||
inner_schedule: LRScheduler,
|
|
||||||
relora_steps: int,
|
|
||||||
warmup_steps: int,
|
|
||||||
min_lr_scale: float = 0.001,
|
|
||||||
) -> None:
|
|
||||||
self.inner_schedule = inner_schedule
|
|
||||||
self.relora_steps = relora_steps
|
|
||||||
self.warmup_steps = warmup_steps
|
|
||||||
self.min_lr_scale = min_lr_scale
|
|
||||||
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
|
||||||
|
|
||||||
def get_lr(self) -> float:
|
|
||||||
self.inner_schedule.last_epoch = self.last_epoch
|
|
||||||
|
|
||||||
original = self.inner_schedule.get_lr()
|
|
||||||
step = self.last_epoch
|
|
||||||
if step < self.relora_steps:
|
|
||||||
scale = 1
|
|
||||||
else:
|
|
||||||
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
|
|
||||||
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
|
||||||
if isinstance(original, Sequence):
|
|
||||||
return [lr * scale for lr in original]
|
|
||||||
else:
|
|
||||||
return original * scale
|
|
||||||
|
|
||||||
|
|
||||||
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
|
|
||||||
model_name = "model.safetensors"
|
|
||||||
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
|
|
||||||
str(Path(path) / f"{model_name}.index.json")
|
|
||||||
):
|
|
||||||
model_name = "pytorch_model.bin"
|
|
||||||
|
|
||||||
index_path = str(Path(path) / f"{model_name}.index.json")
|
|
||||||
if os.path.exists(index_path):
|
|
||||||
data = json.load(open(index_path, "r"))
|
|
||||||
return data["weight_map"]
|
|
||||||
return {key + ".weight": model_name for key in keys}
|
|
||||||
|
|
||||||
|
|
||||||
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer) -> torch.Tensor:
|
|
||||||
if isinstance(layer, peft.tuners.lora.Linear8bitLt) or isinstance(
|
|
||||||
layer, peft.tuners.lora.Linear4bit
|
|
||||||
):
|
|
||||||
adapter = layer.active_adapter
|
|
||||||
return (
|
|
||||||
peft.utils.transpose(
|
|
||||||
layer.lora_B[adapter].weight @ layer.lora_A[adapter].weight,
|
|
||||||
getattr(layer, "fan_in_fan_out", False),
|
|
||||||
)
|
|
||||||
* layer.scaling[adapter]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return layer.get_delta_weight()
|
|
||||||
|
|
||||||
|
|
||||||
def merge_and_save(
|
|
||||||
model: peft.LoraModel,
|
|
||||||
model_src: str,
|
|
||||||
model_dst: str,
|
|
||||||
reinit: bool = False,
|
|
||||||
quantized: bool = False,
|
|
||||||
cpu_offload: bool = False,
|
|
||||||
):
|
|
||||||
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
|
|
||||||
|
|
||||||
if not quantized:
|
|
||||||
for key in key_list:
|
|
||||||
try:
|
|
||||||
_parent, target, _target_name = peft.utils._get_submodules(
|
|
||||||
model.model, key
|
|
||||||
)
|
|
||||||
except AttributeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(target, peft.tuners.lora.LoraLayer):
|
|
||||||
update = target.get_delta_weight(target.active_adapter).detach()
|
|
||||||
target.weight.data += update
|
|
||||||
|
|
||||||
if reinit:
|
|
||||||
for adapter_name in target.lora_A:
|
|
||||||
target.reset_lora_parameters(adapter_name)
|
|
||||||
for adapter_name in target.lora_embedding_A:
|
|
||||||
target.reset_lora_parameters(adapter_name)
|
|
||||||
return
|
|
||||||
|
|
||||||
os.makedirs(model_dst, exist_ok=True)
|
|
||||||
shard_paths = sharded_paths(model_src, key_list)
|
|
||||||
|
|
||||||
unique_shards = list(set(shard_paths.values()))
|
|
||||||
for shard_path in unique_shards:
|
|
||||||
out_tensors = {}
|
|
||||||
if shard_path.endswith(".safetensors"):
|
|
||||||
in_tensors = st.load_file(str(Path(model_src) / shard_path))
|
|
||||||
else:
|
|
||||||
in_tensors = torch.load(Path(model_src) / shard_path)
|
|
||||||
if "state_dict" in in_tensors:
|
|
||||||
in_tensors = in_tensors["state_dict"]
|
|
||||||
|
|
||||||
for key in key_list:
|
|
||||||
if (key + ".weight") not in shard_paths or shard_paths[
|
|
||||||
key + ".weight"
|
|
||||||
] != shard_path:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
_parent, target, _target_name = peft.utils._get_submodules(
|
|
||||||
model.model, key
|
|
||||||
)
|
|
||||||
except AttributeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(target, peft.tuners.lora.LoraLayer):
|
|
||||||
orig_weight = in_tensors[key + ".weight"]
|
|
||||||
old_dev = target.weight.device
|
|
||||||
math_dev = "cpu" if cpu_offload else old_dev
|
|
||||||
|
|
||||||
update = lora_delta_weight(target).detach().to(math_dev)
|
|
||||||
new_weight = orig_weight.to(math_dev) + update
|
|
||||||
out_tensors[key + ".weight"] = new_weight
|
|
||||||
|
|
||||||
if reinit:
|
|
||||||
for adapter_name in target.lora_A:
|
|
||||||
target.reset_lora_parameters(adapter_name)
|
|
||||||
for adapter_name in target.lora_embedding_A:
|
|
||||||
target.reset_lora_parameters(adapter_name)
|
|
||||||
|
|
||||||
if isinstance(target, peft.tuners.lora.Linear4bit):
|
|
||||||
target.weight = (
|
|
||||||
bnb.nn.Params4bit(
|
|
||||||
new_weight,
|
|
||||||
requires_grad=False,
|
|
||||||
compress_statistics=target.weight.compress_statistics,
|
|
||||||
quant_type=target.weight.quant_type,
|
|
||||||
)
|
|
||||||
.cuda(None)
|
|
||||||
.to(old_dev)
|
|
||||||
)
|
|
||||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
|
||||||
target.weight = (
|
|
||||||
bnb.nn.Int8Params(new_weight, requires_grad=False)
|
|
||||||
.cuda(None)
|
|
||||||
.to(old_dev)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
target.weight.data = new_weight.to(old_dev)
|
|
||||||
|
|
||||||
for key in in_tensors:
|
|
||||||
if key not in out_tensors:
|
|
||||||
out_tensors[key] = in_tensors[key]
|
|
||||||
del in_tensors
|
|
||||||
|
|
||||||
out_shard_name = shard_path
|
|
||||||
if out_shard_name.startswith("pytorch_model"):
|
|
||||||
out_shard_name = (
|
|
||||||
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
|
|
||||||
+ ".safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
shard_fn = str(Path(model_dst) / out_shard_name)
|
|
||||||
LOG.info(f"saving tensors to {shard_fn}")
|
|
||||||
st.save_file(out_tensors, shard_fn)
|
|
||||||
del out_tensors
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
if len(unique_shards) > 1:
|
|
||||||
with open(str(Path(model_dst, "model.safetensors.index.json")), "w") as fd:
|
|
||||||
json.dump({"metadata": {}, "weight_map": shard_paths}, fd)
|
|
||||||
@@ -312,7 +312,9 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
# If there isn't a back and forth conversation, ignore it
|
# If there isn't a back and forth conversation, ignore it
|
||||||
# also happens on the data splitting leaving empty conversations
|
# also happens on the data splitting leaving empty conversations
|
||||||
raise IndexError
|
raise IndexError(
|
||||||
|
f"A conversation entry has less than 2 messages :\n{source}"
|
||||||
|
)
|
||||||
|
|
||||||
conv = self._conversation.copy()
|
conv = self._conversation.copy()
|
||||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||||
|
|||||||
@@ -4,13 +4,23 @@ import pynvml
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def gpu_memory_usage(device):
|
def gpu_memory_usage(device=0):
|
||||||
|
return torch.cuda.memory_allocated(device) / 1024.0**3
|
||||||
|
|
||||||
|
|
||||||
|
def gpu_memory_usage_all(device=0):
|
||||||
|
usage = torch.cuda.memory_allocated(device) / 1024.0**3
|
||||||
|
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
|
||||||
|
smi = gpu_memory_usage_smi(device)
|
||||||
|
return usage, reserved - usage, max(0, smi - reserved)
|
||||||
|
|
||||||
|
|
||||||
|
def gpu_memory_usage_smi(device=0):
|
||||||
if isinstance(device, torch.device):
|
if isinstance(device, torch.device):
|
||||||
device = device.index
|
device = device.index
|
||||||
if isinstance(device, str) and device.startswith("cuda:"):
|
if isinstance(device, str) and device.startswith("cuda:"):
|
||||||
device = int(device[5:])
|
device = int(device[5:])
|
||||||
|
|
||||||
# NB torch.cuda.memory_usage returns zero so we use lower level api
|
|
||||||
pynvml.nvmlInit()
|
pynvml.nvmlInit()
|
||||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||||
@@ -18,6 +28,16 @@ def gpu_memory_usage(device):
|
|||||||
|
|
||||||
|
|
||||||
def log_gpu_memory_usage(log, msg, device):
|
def log_gpu_memory_usage(log, msg, device):
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return (0, 0, 0)
|
||||||
|
|
||||||
|
usage, cache, misc = gpu_memory_usage_all(device)
|
||||||
|
extras = []
|
||||||
|
if cache > 0:
|
||||||
|
extras.append(f"+{cache:.03f}GB cache")
|
||||||
|
if misc > 0:
|
||||||
|
extras.append(f"+{misc:.03f}GB misc")
|
||||||
log.info(
|
log.info(
|
||||||
f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
|
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
|
||||||
)
|
)
|
||||||
|
return usage, cache, misc
|
||||||
|
|||||||
@@ -33,9 +33,7 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
|||||||
)
|
)
|
||||||
|
|
||||||
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
||||||
kwargs["model"].save_pretrained(
|
kwargs["model"].save_pretrained(peft_model_path)
|
||||||
peft_model_path, save_safetensors=args.save_safetensors
|
|
||||||
)
|
|
||||||
|
|
||||||
return control
|
return control
|
||||||
|
|
||||||
@@ -76,10 +74,10 @@ class SaveBetterTransformerModelCallback(
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class PrintGPUStatsCallback(
|
class GPUStatsCallback(
|
||||||
TrainerCallback
|
TrainerCallback
|
||||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
): # pylint: disable=too-few-public-methods disable=unused-argument
|
||||||
"""Callback to print GPU utilization"""
|
"""Callback to track GPU utilization"""
|
||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
@@ -92,7 +90,7 @@ class PrintGPUStatsCallback(
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if not self.logged:
|
if not self.logged and state.global_step > 1:
|
||||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
||||||
self.logged = True
|
self.logged = True
|
||||||
return control
|
return control
|
||||||
|
|||||||
@@ -1,12 +1,70 @@
|
|||||||
"""Module for validating config files"""
|
"""Module for working with config dicts"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
def choose_device(cfg):
|
||||||
|
def get_device():
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return f"cuda:{cfg.local_rank}"
|
||||||
|
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
return "mps"
|
||||||
|
|
||||||
|
raise SystemError("No CUDA/mps device found")
|
||||||
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
|
cfg.device = get_device()
|
||||||
|
if cfg.device_map != "auto":
|
||||||
|
if cfg.device.startswith("cuda"):
|
||||||
|
cfg.device_map = {"": cfg.local_rank}
|
||||||
|
else:
|
||||||
|
cfg.device_map = {"": cfg.device}
|
||||||
|
|
||||||
|
# in `accelerate launch`, we need to not pass through any device map and let
|
||||||
|
# accelerate figure out which parts of the model to put on which gpu
|
||||||
|
accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
|
||||||
|
if accelerate_vars:
|
||||||
|
cfg.device_map = None
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_config(cfg):
|
||||||
|
# setup some derived config / hyperparams
|
||||||
|
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||||
|
cfg.batch_size // cfg.micro_batch_size
|
||||||
|
)
|
||||||
|
cfg.batch_size = (
|
||||||
|
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||||
|
)
|
||||||
|
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
choose_device(cfg)
|
||||||
|
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||||
|
if cfg.ddp:
|
||||||
|
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||||
|
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||||
|
|
||||||
|
if cfg.device == "mps":
|
||||||
|
cfg.load_in_8bit = False
|
||||||
|
cfg.tf32 = False
|
||||||
|
if cfg.bf16:
|
||||||
|
cfg.fp16 = True
|
||||||
|
cfg.bf16 = False
|
||||||
|
else:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
||||||
|
|
||||||
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -61,9 +119,6 @@ def validate_config(cfg):
|
|||||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||||
|
|
||||||
if cfg.relora_steps and cfg.adapter not in ("lora", "qlora"):
|
|
||||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
|
||||||
|
|
||||||
if cfg.trust_remote_code:
|
if cfg.trust_remote_code:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||||
@@ -10,3 +10,6 @@ class DictDefault(Dict):
|
|||||||
|
|
||||||
def __missing__(self, key):
|
def __missing__(self, key):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def __or__(self, other):
|
||||||
|
return DictDefault(super().__or__(other))
|
||||||
|
|||||||
@@ -32,37 +32,66 @@ if TYPE_CHECKING:
|
|||||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(
|
def smart_tokenizer_and_embedding_resize(
|
||||||
tokenizer_config,
|
tokenizer: transformers.PreTrainedTokenizer,
|
||||||
tokenizer_type,
|
model: transformers.PreTrainedModel,
|
||||||
cfg,
|
resize_token_embeddings_multiple: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
"""Resize tokenizer and embedding.
|
||||||
|
|
||||||
|
Note: This function resizes the tokenizer to accommodate additional special tokens and the
|
||||||
|
embedding matrix of the model to match the new size of the tokenizer. If any new special tokens
|
||||||
|
have been added, the function computes the average embedding values of the existing embeddings
|
||||||
|
and sets those values for the new special token embeddings. This is done separately for the input
|
||||||
|
embeddings and output embeddings of the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
old_tokens = model.get_input_embeddings().weight.data.shape[0]
|
||||||
|
num_new_tokens = len(tokenizer) - old_tokens
|
||||||
|
embeddings_len = (
|
||||||
|
math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
|
||||||
|
* resize_token_embeddings_multiple
|
||||||
|
if resize_token_embeddings_multiple
|
||||||
|
else len(tokenizer)
|
||||||
|
)
|
||||||
|
model.resize_token_embeddings(embeddings_len)
|
||||||
|
|
||||||
|
if num_new_tokens > 0:
|
||||||
|
input_embeddings = model.get_input_embeddings().weight.data
|
||||||
|
output_embeddings = model.get_output_embeddings().weight.data
|
||||||
|
|
||||||
|
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
||||||
|
dim=0, keepdim=True
|
||||||
|
)
|
||||||
|
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
||||||
|
dim=0, keepdim=True
|
||||||
|
)
|
||||||
|
|
||||||
|
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
||||||
|
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
||||||
|
|
||||||
|
|
||||||
|
def load_tokenizer(cfg):
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
use_fast = True # this is the default
|
use_fast = True # this is the default
|
||||||
|
|
||||||
if cfg.tokenizer_use_fast is not None:
|
if cfg.tokenizer_use_fast is not None:
|
||||||
use_fast = cfg.tokenizer_use_fast
|
use_fast = cfg.tokenizer_use_fast
|
||||||
if cfg.tokenizer_legacy is not None:
|
if cfg.tokenizer_legacy is not None:
|
||||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||||
if tokenizer_type:
|
|
||||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
|
||||||
tokenizer_config,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
|
||||||
use_fast=use_fast,
|
|
||||||
**tokenizer_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
tokenizer_config,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
|
||||||
use_fast=use_fast,
|
|
||||||
**tokenizer_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
tokenizer_cls = AutoTokenizer
|
||||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
if cfg.tokenizer_type:
|
||||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
|
||||||
|
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||||
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
|
tokenizer_config,
|
||||||
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
use_fast=use_fast,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ in [
|
if tokenizer.__class__.__name__ in [
|
||||||
"LlamaTokenizer",
|
"LlamaTokenizer",
|
||||||
@@ -70,6 +99,11 @@ def load_tokenizer(
|
|||||||
]:
|
]:
|
||||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
|
||||||
|
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||||
|
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||||
|
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
|
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -92,7 +126,6 @@ def load_model(
|
|||||||
base_model = cfg.base_model
|
base_model = cfg.base_model
|
||||||
base_model_config = cfg.base_model_config
|
base_model_config = cfg.base_model_config
|
||||||
model_type = cfg.model_type
|
model_type = cfg.model_type
|
||||||
adapter = cfg.adapter
|
|
||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
@@ -235,12 +268,17 @@ def load_model(
|
|||||||
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
|
config_kwargs = {}
|
||||||
|
if cfg.rope_scaling:
|
||||||
|
config_kwargs["rope_scaling"] = cfg.rope_scaling
|
||||||
config = LlamaConfig.from_pretrained(
|
config = LlamaConfig.from_pretrained(
|
||||||
base_model_config, rope_scaling=cfg.rope_scaling
|
base_model_config,
|
||||||
|
**config_kwargs,
|
||||||
)
|
)
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -275,6 +313,7 @@ def load_model(
|
|||||||
elif model_type and not cfg.trust_remote_code:
|
elif model_type and not cfg.trust_remote_code:
|
||||||
model = getattr(transformers, model_type).from_pretrained(
|
model = getattr(transformers, model_type).from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -305,6 +344,7 @@ def load_model(
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -318,6 +358,7 @@ def load_model(
|
|||||||
LOG.exception(err)
|
LOG.exception(err)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -325,17 +366,16 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
embeddings_len = (
|
smart_tokenizer_and_embedding_resize(
|
||||||
math.ceil(len(tokenizer) / 32) * 32
|
tokenizer,
|
||||||
if cfg.resize_token_embeddings_to_32x
|
model,
|
||||||
else len(tokenizer)
|
resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
|
||||||
)
|
)
|
||||||
model.resize_token_embeddings(embeddings_len)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model.config, "max_position_embeddings")
|
hasattr(model.config, "max_position_embeddings")
|
||||||
and model.config.max_position_embeddings
|
and model.config.max_position_embeddings
|
||||||
and cfg.sequence_len >= model.config.max_position_embeddings
|
and cfg.sequence_len > model.config.max_position_embeddings
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
||||||
@@ -364,7 +404,7 @@ def load_model(
|
|||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
module.to(torch_dtype)
|
module.to(torch_dtype)
|
||||||
|
|
||||||
model, lora_config = load_adapter(model, cfg, adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|
||||||
if cfg.ddp and not load_in_8bit:
|
if cfg.ddp and not load_in_8bit:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
@@ -381,9 +421,6 @@ def load_model(
|
|||||||
module.scales = module.scales.half()
|
module.scales = module.scales.half()
|
||||||
module.bias = module.bias.half()
|
module.bias = module.bias.half()
|
||||||
|
|
||||||
if model.device.type == "cuda":
|
|
||||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
torch.cuda.device_count() > 1
|
torch.cuda.device_count() > 1
|
||||||
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
||||||
@@ -406,6 +443,9 @@ def load_model(
|
|||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.transform(model)
|
model = BetterTransformer.transform(model)
|
||||||
|
|
||||||
|
if cfg.adapter is not None:
|
||||||
|
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|||||||
@@ -21,9 +21,8 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
|||||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
PrintGPUStatsCallback,
|
GPUStatsCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
SavePeftModelCallback,
|
||||||
)
|
)
|
||||||
@@ -441,6 +440,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
training_arguments_kwargs["push_to_hub"] = True
|
training_arguments_kwargs["push_to_hub"] = True
|
||||||
training_arguments_kwargs["hub_private_repo"] = True
|
training_arguments_kwargs["hub_private_repo"] = True
|
||||||
|
|
||||||
|
if cfg.hub_strategy:
|
||||||
|
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
|
||||||
|
|
||||||
if cfg.save_safetensors:
|
if cfg.save_safetensors:
|
||||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||||
|
|
||||||
@@ -449,8 +451,17 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
"sample_packing_efficiency"
|
"sample_packing_efficiency"
|
||||||
] = cfg.sample_packing_eff_est
|
] = cfg.sample_packing_eff_est
|
||||||
|
|
||||||
|
if cfg.val_set_size == 0:
|
||||||
|
evaluation_strategy = "no"
|
||||||
|
elif cfg.eval_steps < 1:
|
||||||
|
# eval every epoch
|
||||||
|
evaluation_strategy = "epoch"
|
||||||
|
else:
|
||||||
|
# eval every eval_steps steps
|
||||||
|
evaluation_strategy = "steps"
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||||
max_seq_length=cfg.sequence_len,
|
max_seq_length=cfg.sequence_len,
|
||||||
per_device_train_batch_size=cfg.micro_batch_size,
|
per_device_train_batch_size=cfg.micro_batch_size,
|
||||||
per_device_eval_batch_size=cfg.eval_batch_size
|
per_device_eval_batch_size=cfg.eval_batch_size
|
||||||
@@ -460,7 +471,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
num_train_epochs=cfg.num_epochs,
|
num_train_epochs=cfg.num_epochs,
|
||||||
learning_rate=cfg.learning_rate,
|
learning_rate=cfg.learning_rate,
|
||||||
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
evaluation_strategy=evaluation_strategy,
|
||||||
save_strategy="steps" if cfg.save_steps else "epoch",
|
save_strategy="steps" if cfg.save_steps else "epoch",
|
||||||
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
||||||
save_steps=cfg.save_steps,
|
save_steps=cfg.save_steps,
|
||||||
@@ -556,19 +567,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
||||||
|
|
||||||
callbacks = []
|
callbacks = []
|
||||||
callbacks.append(PrintGPUStatsCallback(cfg))
|
callbacks.append(GPUStatsCallback(cfg))
|
||||||
|
|
||||||
if cfg.relora_steps:
|
|
||||||
relora_steps = int(cfg.relora_steps)
|
|
||||||
relora_warmup_steps = int(cfg.relora_warmup_steps)
|
|
||||||
callbacks.append(ReLoRACallback(cfg))
|
|
||||||
|
|
||||||
(optimizer, lr_scheduler) = trainer_kwargs["optimizers"]
|
|
||||||
trainer_kwargs["optimizers"] = (
|
|
||||||
optimizer,
|
|
||||||
ReLoRAScheduler(optimizer, lr_scheduler, relora_steps, relora_warmup_steps),
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||||
if cfg.early_stopping_patience:
|
if cfg.early_stopping_patience:
|
||||||
early_stop_cb = EarlyStoppingCallback(
|
early_stop_cb = EarlyStoppingCallback(
|
||||||
|
|||||||
@@ -72,6 +72,13 @@ class DictDefaultTest(unittest.TestCase):
|
|||||||
|
|
||||||
assert cfg.random_key is None, "DictDefault should return None for missing keys"
|
assert cfg.random_key is None, "DictDefault should return None for missing keys"
|
||||||
|
|
||||||
|
def test_dict_or(self):
|
||||||
|
cfg = DictDefault({}) | DictDefault({})
|
||||||
|
|
||||||
|
assert (
|
||||||
|
cfg.random_key is None
|
||||||
|
), "DictDefault should return None for missing keys after | operation"
|
||||||
|
|
||||||
def test_dict_nested_missingparentkey(self):
|
def test_dict_nested_missingparentkey(self):
|
||||||
"""
|
"""
|
||||||
Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.
|
Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.
|
||||||
|
|||||||
@@ -13,17 +13,22 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def test_default_use_fast(self):
|
def test_default_use_fast(self):
|
||||||
cfg = DictDefault({})
|
cfg = DictDefault(
|
||||||
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert "Fast" in tokenizer.__class__.__name__
|
assert "Fast" in tokenizer.__class__.__name__
|
||||||
|
|
||||||
def test_dont_use_fast(self):
|
def test_dont_use_fast(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"tokenizer_use_fast": False,
|
"tokenizer_use_fast": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert "Fast" not in tokenizer.__class__.__name__
|
assert "Fast" not in tokenizer.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from typing import Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.validation import validate_config
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationTest(unittest.TestCase):
|
class ValidationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user