Compare commits
6 Commits
feature/at
...
feature/re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1afbd8af2d | ||
|
|
b4f2eea2ed | ||
|
|
bbf88b02c1 | ||
|
|
64a8e04430 | ||
|
|
c8f7213bc6 | ||
|
|
b57238ecec |
13
.github/FUNDING.yml
vendored
13
.github/FUNDING.yml
vendored
@@ -1,13 +0,0 @@
|
|||||||
# These are supported funding model platforms
|
|
||||||
|
|
||||||
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
|
||||||
patreon: # Replace with a single Patreon username
|
|
||||||
open_collective: # Replace with a single Open Collective username
|
|
||||||
ko_fi: # Replace with a single Ko-fi username
|
|
||||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
|
||||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
|
||||||
liberapay: # Replace with a single Liberapay username
|
|
||||||
issuehunt: # Replace with a single IssueHunt username
|
|
||||||
otechie: # Replace with a single Otechie username
|
|
||||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
|
||||||
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
|
||||||
@@ -136,7 +136,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"instruction": "...", "input": "...", "output": "..."}
|
{"instruction": "...", "input": "...", "output": "..."}
|
||||||
```
|
```
|
||||||
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
|
- `sharegpt:chat`: conversations
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
@@ -225,10 +225,6 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"conversations": [{"role": "...", "value": "..."}]}
|
{"conversations": [{"role": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
|
||||||
```json
|
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
|
||||||
```
|
|
||||||
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
|
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from optimum.bettertransformer import BetterTransformer
|
|||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import barrier, is_main_process
|
from axolotl.utils.distributed import barrier, is_main_process
|
||||||
@@ -29,6 +29,7 @@ from axolotl.utils.trainer import (
|
|||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
setup_trainer,
|
setup_trainer,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.validation import validate_config
|
||||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -43,6 +44,27 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
|||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
def choose_device(cfg):
|
||||||
|
def get_device():
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return f"cuda:{cfg.local_rank}"
|
||||||
|
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
return "mps"
|
||||||
|
|
||||||
|
raise SystemError("No CUDA/mps device found")
|
||||||
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
|
cfg.device = get_device()
|
||||||
|
if cfg.device_map != "auto":
|
||||||
|
if cfg.device.startswith("cuda"):
|
||||||
|
cfg.device_map = {"": cfg.local_rank}
|
||||||
|
else:
|
||||||
|
cfg.device_map = {"": cfg.device}
|
||||||
|
|
||||||
|
|
||||||
def get_multi_line_input() -> Optional[str]:
|
def get_multi_line_input() -> Optional[str]:
|
||||||
print("Give me an instruction (Ctrl + D to finish): ")
|
print("Give me an instruction (Ctrl + D to finish): ")
|
||||||
instruction = ""
|
instruction = ""
|
||||||
@@ -172,13 +194,36 @@ def train(
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
normalize_config(cfg)
|
# setup some derived config / hyperparams
|
||||||
|
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||||
|
cfg.batch_size // cfg.micro_batch_size
|
||||||
|
)
|
||||||
|
cfg.batch_size = (
|
||||||
|
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||||
|
)
|
||||||
|
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
choose_device(cfg)
|
||||||
|
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||||
|
if cfg.ddp:
|
||||||
|
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||||
|
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
|
if cfg.device == "mps":
|
||||||
|
cfg.load_in_8bit = False
|
||||||
|
cfg.tf32 = False
|
||||||
|
if cfg.bf16:
|
||||||
|
cfg.fp16 = True
|
||||||
|
cfg.bf16 = False
|
||||||
|
|
||||||
|
if cfg.tf32:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||||
tokenizer = load_tokenizer(cfg)
|
LOG.info(f"loading tokenizer... {tokenizer_config}")
|
||||||
|
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
||||||
@@ -224,6 +269,8 @@ def train(
|
|||||||
LOG.info("Finished preparing dataset. Exiting...")
|
LOG.info("Finished preparing dataset. Exiting...")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
LOG.info("loading model and (optionally) peft_config...")
|
LOG.info("loading model and (optionally) peft_config...")
|
||||||
model, peft_config = load_model(cfg, tokenizer)
|
model, peft_config = load_model(cfg, tokenizer)
|
||||||
@@ -307,7 +354,6 @@ def train(
|
|||||||
|
|
||||||
if not Path(cfg.output_dir).is_dir():
|
if not Path(cfg.output_dir).is_dir():
|
||||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||||
tokenizer.save_pretrained(cfg.output_dir)
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with torch.backends.cuda.sdp_kernel(
|
||||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||||
@@ -325,8 +371,14 @@ def train(
|
|||||||
elif cfg.local_rank == 0:
|
elif cfg.local_rank == 0:
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
|
||||||
|
if cfg.adapter == "lora" and cfg.relora_steps:
|
||||||
|
model = model.merge_and_unload()
|
||||||
|
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
|
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
fire.Fire(train)
|
fire.Fire(train)
|
||||||
|
|||||||
@@ -2,54 +2,26 @@
|
|||||||
|
|
||||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import transformers
|
import transformers
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||||
flash_attn_kvpacked_func,
|
|
||||||
flash_attn_varlen_kvpacked_func,
|
|
||||||
flash_attn_varlen_qkvpacked_func,
|
|
||||||
)
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from flash_attn.flash_attn_interface import (
|
|
||||||
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
|
|
||||||
)
|
|
||||||
from flash_attn.flash_attn_interface import (
|
from flash_attn.flash_attn_interface import (
|
||||||
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
|
|
||||||
def replace_llama_attn_with_flash_attn():
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
|
||||||
_prepare_decoder_attention_mask
|
|
||||||
)
|
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
def forward(
|
||||||
# requires the attention mask to be the same as the key_padding_mask
|
|
||||||
def _prepare_decoder_attention_mask(
|
|
||||||
self,
|
|
||||||
attention_mask,
|
|
||||||
input_shape,
|
|
||||||
inputs_embeds,
|
|
||||||
past_key_values_length,
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
# [bsz, seq_len]
|
|
||||||
return attention_mask
|
|
||||||
|
|
||||||
|
|
||||||
def flashattn_forward(
|
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
@@ -65,294 +37,124 @@ def flashattn_forward(
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
if not hasattr(self, "pretraining_tp"):
|
query_states = (
|
||||||
self.pretraining_tp = 1
|
self.q_proj(hidden_states)
|
||||||
|
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
if self.pretraining_tp > 1:
|
.transpose(1, 2)
|
||||||
key_value_slicing = (
|
)
|
||||||
self.num_key_value_heads * self.head_dim
|
key_states = (
|
||||||
) // self.pretraining_tp
|
self.k_proj(hidden_states)
|
||||||
query_slices = self.q_proj.weight.split(
|
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
.transpose(1, 2)
|
||||||
)
|
)
|
||||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
value_states = (
|
||||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
self.v_proj(hidden_states)
|
||||||
|
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
query_states = [
|
.transpose(1, 2)
|
||||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
)
|
||||||
]
|
|
||||||
query_states = torch.cat(query_states, dim=-1)
|
|
||||||
|
|
||||||
key_states = [
|
|
||||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
|
||||||
]
|
|
||||||
key_states = torch.cat(key_states, dim=-1)
|
|
||||||
|
|
||||||
value_states = [
|
|
||||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
|
||||||
]
|
|
||||||
value_states = torch.cat(value_states, dim=-1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
query_states = query_states.view(
|
|
||||||
bsz, q_len, self.num_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
key_states = key_states.view(
|
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
value_states = value_states.view(
|
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
# [bsz, q_len, nh, hd]
|
# [bsz, q_len, nh, hd]
|
||||||
# [bsz, nh, q_len, hd]
|
# [bsz, nh, q_len, hd]
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
assert past_key_value is None, "past_key_value is not supported"
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
# [bsz, nh, t, hd]
|
# [bsz, nh, t, hd]
|
||||||
|
assert not output_attentions, "output_attentions is not supported"
|
||||||
|
assert not use_cache, "use_cache is not supported"
|
||||||
|
|
||||||
if past_key_value is not None:
|
# Flash attention codes from
|
||||||
# reuse k, v, self_attention
|
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
# transform the data into the format required by flash attention
|
||||||
|
qkv = torch.stack(
|
||||||
|
[query_states, key_states, value_states], dim=2
|
||||||
|
) # [bsz, nh, 3, q_len, hd]
|
||||||
|
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||||
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
|
key_padding_mask = attention_mask
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
if key_padding_mask is None:
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
max_s = q_len
|
||||||
|
cu_q_lens = torch.arange(
|
||||||
if output_attentions:
|
0,
|
||||||
warnings.warn(
|
(bsz + 1) * q_len,
|
||||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
step=q_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=qkv.device,
|
||||||
)
|
)
|
||||||
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
#
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
# flash-attn v2 start
|
)
|
||||||
#
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
elif position_ids.shape[0] == 1:
|
||||||
if self.training:
|
|
||||||
# during training q,k,v always have same seqlen
|
|
||||||
assert key_states.shape == query_states.shape
|
|
||||||
is_causal = True
|
|
||||||
else:
|
|
||||||
# turn off FA causal mask after first inference autoregressive iteration
|
|
||||||
# only on first autoregressive step q,k,v have same seqlen
|
|
||||||
is_causal = past_key_value is not None
|
|
||||||
|
|
||||||
if self.training and attention_mask.shape[0] == 1:
|
|
||||||
# special handling using sample packing
|
# special handling using sample packing
|
||||||
qkv = torch.stack(
|
|
||||||
[query_states, key_states, value_states], dim=2
|
|
||||||
) # [bsz, nh, 3, q_len, hd]
|
|
||||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
|
||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
cu_q_lens = cu_q_lens.squeeze()
|
cu_q_lens = cu_q_lens.squeeze()
|
||||||
|
|
||||||
output = flash_attn_varlen_qkvpacked_func(
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=is_causal
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
elif query_states.shape == key_states.shape:
|
else:
|
||||||
query_states = query_states.transpose(1, 2)
|
nheads = qkv.shape[-2]
|
||||||
key_states = key_states.transpose(1, 2)
|
|
||||||
value_states = value_states.transpose(1, 2)
|
# pylint: disable=invalid-name
|
||||||
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||||
query_states,
|
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
||||||
key_states,
|
x_unpad = rearrange(
|
||||||
value_states,
|
x_unpad,
|
||||||
qkvpacked=True,
|
"nnz (three h d) -> nnz three h d",
|
||||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
three=3,
|
||||||
# the attention_mask should be the same as the key_padding_mask
|
h=nheads,
|
||||||
key_padding_mask=attention_mask,
|
|
||||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
|
||||||
if attention_mask is not None
|
|
||||||
else None,
|
|
||||||
)
|
)
|
||||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv_unpad,
|
x_unpad,
|
||||||
cu_seqlens_q,
|
cu_q_lens,
|
||||||
max_seqlen_q,
|
max_s,
|
||||||
0.0,
|
0.0,
|
||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
causal=is_causal,
|
causal=True,
|
||||||
)
|
)
|
||||||
output = output_pad_fn(output_unpad)
|
output = rearrange(
|
||||||
else:
|
pad_input(
|
||||||
query_states = query_states.transpose(1, 2)
|
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
|
||||||
key_states = key_states.transpose(1, 2)
|
indices,
|
||||||
value_states = value_states.transpose(1, 2)
|
bsz,
|
||||||
if attention_mask is None or attention_mask.all().item():
|
q_len,
|
||||||
output = flash_attn_kvpacked_func(
|
),
|
||||||
query_states,
|
"b s (h d) -> b s h d",
|
||||||
torch.stack([key_states, value_states], 2),
|
h=nheads,
|
||||||
causal=is_causal,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
( # pylint: disable=unbalanced-tuple-unpacking
|
|
||||||
q_unpad,
|
|
||||||
kv_unpad,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
max_seqlen_q,
|
|
||||||
max_seqlen_k,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
output_pad_fn,
|
|
||||||
) = generate_qkv(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
kvpacked=True,
|
|
||||||
key_padding_mask=attention_mask,
|
|
||||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
|
||||||
if attention_mask is not None
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
|
||||||
q_unpad,
|
|
||||||
kv_unpad,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
max_seqlen_q,
|
|
||||||
max_seqlen_k,
|
|
||||||
0.0,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=is_causal,
|
|
||||||
)
|
|
||||||
output = output_pad_fn(output_unpad)
|
|
||||||
|
|
||||||
attn_output = output
|
|
||||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
|
||||||
raise ValueError(
|
|
||||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
|
||||||
f" {attn_output.size()}"
|
|
||||||
)
|
|
||||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
|
||||||
|
|
||||||
#
|
|
||||||
# flash-attn v2 end
|
|
||||||
#
|
|
||||||
|
|
||||||
if self.pretraining_tp > 1:
|
|
||||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
|
||||||
o_proj_slices = self.o_proj.weight.split(
|
|
||||||
self.hidden_size // self.pretraining_tp, dim=1
|
|
||||||
)
|
|
||||||
attn_output = sum(
|
|
||||||
F.linear(attn_output[i], o_proj_slices[i])
|
|
||||||
for i in range(self.pretraining_tp)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
|
||||||
def generate_qkv(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
query_padding_mask=None,
|
|
||||||
key_padding_mask=None,
|
|
||||||
kvpacked=False,
|
|
||||||
qkvpacked=False,
|
|
||||||
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
|
||||||
"""
|
|
||||||
Arguments:
|
|
||||||
q: (batch_size, seqlen_q, nheads, d)
|
|
||||||
k: (batch_size, seqlen_k, nheads_k, d)
|
|
||||||
v: (batch_size, seqlen_k, nheads_k, d)
|
|
||||||
query_padding_mask: (batch_size, seqlen), bool
|
|
||||||
key_padding_mask: (batch_size, seqlen), bool
|
|
||||||
"""
|
|
||||||
assert not (kvpacked and qkvpacked)
|
|
||||||
batch_size, seqlen_q, nheads, d = q.shape
|
|
||||||
_, seqlen_k, nheads_k, _ = k.shape
|
|
||||||
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
|
||||||
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
|
||||||
|
|
||||||
if query_padding_mask is not None:
|
|
||||||
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
|
||||||
q, query_padding_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
|
||||||
output_unpad, indices_q, batch_size, seqlen_q
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
|
||||||
cu_seqlens_q = torch.arange(
|
|
||||||
0,
|
|
||||||
(batch_size + 1) * seqlen_q,
|
|
||||||
step=seqlen_q,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=q_unpad.device,
|
|
||||||
)
|
|
||||||
max_seqlen_q = seqlen_q
|
|
||||||
|
|
||||||
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
|
||||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if key_padding_mask is not None:
|
|
||||||
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
|
||||||
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
|
||||||
else:
|
|
||||||
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
|
||||||
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
|
||||||
cu_seqlens_k = torch.arange(
|
|
||||||
0,
|
|
||||||
(batch_size + 1) * seqlen_k,
|
|
||||||
step=seqlen_k,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=k_unpad.device,
|
|
||||||
)
|
|
||||||
max_seqlen_k = seqlen_k
|
|
||||||
|
|
||||||
if qkvpacked:
|
|
||||||
assert nheads == nheads_k
|
|
||||||
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
|
||||||
qkv = torch.stack([q, k, v], dim=2)
|
|
||||||
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
|
||||||
|
|
||||||
if kvpacked:
|
|
||||||
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
|
||||||
kv = torch.stack([k, v], dim=2)
|
|
||||||
return (
|
|
||||||
q_unpad,
|
|
||||||
kv_unpad,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
max_seqlen_q,
|
|
||||||
max_seqlen_k,
|
|
||||||
q,
|
|
||||||
kv,
|
|
||||||
output_pad_fn,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
q_unpad,
|
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
|
||||||
k_unpad,
|
None,
|
||||||
v_unpad,
|
None,
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
max_seqlen_q,
|
|
||||||
max_seqlen_k,
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
output_pad_fn,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
|
def _prepare_decoder_attention_mask(
|
||||||
|
self,
|
||||||
|
attention_mask,
|
||||||
|
input_shape,
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
# [bsz, seq_len]
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
def replace_llama_attn_with_flash_attn():
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||||
|
_prepare_decoder_attention_mask
|
||||||
|
)
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
||||||
|
|||||||
@@ -1,140 +0,0 @@
|
|||||||
"""
|
|
||||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
|
||||||
"""
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
|
||||||
|
|
||||||
|
|
||||||
def hijack_llama_sdp_attention():
|
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
|
||||||
sdp_attention_forward
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def sdp_attention_forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
if not hasattr(self, "pretraining_tp"):
|
|
||||||
self.pretraining_tp = 1
|
|
||||||
|
|
||||||
if self.pretraining_tp > 1:
|
|
||||||
key_value_slicing = (
|
|
||||||
self.num_key_value_heads * self.head_dim
|
|
||||||
) // self.pretraining_tp
|
|
||||||
query_slices = self.q_proj.weight.split(
|
|
||||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
|
||||||
)
|
|
||||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
|
|
||||||
query_states = [
|
|
||||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
|
||||||
]
|
|
||||||
query_states = torch.cat(query_states, dim=-1)
|
|
||||||
|
|
||||||
key_states = [
|
|
||||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
|
||||||
]
|
|
||||||
key_states = torch.cat(key_states, dim=-1)
|
|
||||||
|
|
||||||
value_states = [
|
|
||||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
|
||||||
]
|
|
||||||
value_states = torch.cat(value_states, dim=-1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
query_states = query_states.view(
|
|
||||||
bsz, q_len, self.num_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
key_states = key_states.view(
|
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
value_states = value_states.view(
|
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
# [bsz, q_len, nh, hd]
|
|
||||||
# [bsz, nh, q_len, hd]
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
if past_key_value is not None:
|
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
|
||||||
query_states, key_states, cos, sin, position_ids
|
|
||||||
)
|
|
||||||
# [bsz, nh, t, hd]
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
warnings.warn(
|
|
||||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
#
|
|
||||||
# sdp-attn start
|
|
||||||
#
|
|
||||||
|
|
||||||
with torch.backends.cuda.sdp_kernel():
|
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
attn_mask=attention_mask,
|
|
||||||
is_causal=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
|
||||||
raise ValueError(
|
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
|
||||||
f" {attn_output.size()}"
|
|
||||||
)
|
|
||||||
attn_output = attn_output.transpose(1, 2)
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
|
||||||
|
|
||||||
#
|
|
||||||
# sdp-attn end
|
|
||||||
#
|
|
||||||
|
|
||||||
if self.pretraining_tp > 1:
|
|
||||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
|
||||||
o_proj_slices = self.o_proj.weight.split(
|
|
||||||
self.hidden_size // self.pretraining_tp, dim=1
|
|
||||||
)
|
|
||||||
attn_output = sum(
|
|
||||||
F.linear(attn_output[i], o_proj_slices[i])
|
|
||||||
for i in range(self.pretraining_tp)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
|
||||||
@@ -3,13 +3,13 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-g
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import math
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import transformers.models.llama.modeling_llama
|
import transformers.models.llama.modeling_llama
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
from torch import nn
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
@@ -21,6 +21,12 @@ def hijack_llama_attention():
|
|||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||||
|
|
||||||
|
|
||||||
|
def hijack_llama_sdp_attention():
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||||
|
sdp_attention_forward
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def xformers_forward(
|
def xformers_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -75,15 +81,15 @@ def xformers_forward(
|
|||||||
value_states = value_states.view(
|
value_states = value_states.view(
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
).transpose(1, 2)
|
).transpose(1, 2)
|
||||||
# [bsz, q_len, nh, hd]
|
|
||||||
# [bsz, nh, q_len, hd]
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
# [bsz, nh, t, hd]
|
# [bsz, nh, t, hd]
|
||||||
@@ -96,50 +102,74 @@ def xformers_forward(
|
|||||||
past_key_value = (key_states, value_states) if use_cache else None
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = transformers.models.llama.modeling_llama.repeat_kv(
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
key_states, self.num_key_value_groups
|
||||||
|
)
|
||||||
|
value_states = transformers.models.llama.modeling_llama.repeat_kv(
|
||||||
|
value_states, self.num_key_value_groups
|
||||||
|
)
|
||||||
|
|
||||||
if output_attentions:
|
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||||
warnings.warn(
|
if not output_attentions:
|
||||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
query_states = query_states.transpose(1, 2)
|
||||||
)
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
#
|
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||||
# xformers-attn start
|
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||||
#
|
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||||
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
query_states = query_states.transpose(1, 2)
|
attn_output = xformers.ops.memory_efficient_attention(
|
||||||
key_states = key_states.transpose(1, 2)
|
query_states, key_states, value_states, attn_bias=None
|
||||||
value_states = value_states.transpose(1, 2)
|
)
|
||||||
|
else:
|
||||||
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
attn_output = xformers.ops.memory_efficient_attention(
|
||||||
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
query_states,
|
||||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
key_states,
|
||||||
attn_output = xformers.ops.memory_efficient_attention(
|
value_states,
|
||||||
query_states, key_states, value_states, attn_bias=None
|
# attn_bias=attention_mask,
|
||||||
)
|
attn_bias=xformers.ops.LowerTriangularMask(),
|
||||||
|
)
|
||||||
|
attn_weights = None
|
||||||
else:
|
else:
|
||||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
attn_weights = torch.matmul(
|
||||||
attn_output = xformers.ops.memory_efficient_attention(
|
query_states, key_states.transpose(2, 3)
|
||||||
query_states,
|
) / math.sqrt(self.head_dim)
|
||||||
key_states,
|
|
||||||
value_states,
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
# attn_bias=attention_mask,
|
raise ValueError(
|
||||||
attn_bias=xformers.ops.LowerTriangularMask(),
|
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
)
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(
|
||||||
|
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||||
|
)
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(
|
||||||
|
attn_weights, dim=-1, dtype=torch.float32
|
||||||
|
).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
# end x-formers vs. not x-formers if-else block
|
||||||
|
|
||||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
|
||||||
raise ValueError(
|
|
||||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
|
||||||
f" {attn_output.size()}"
|
|
||||||
)
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
#
|
|
||||||
# xformers-attn end
|
|
||||||
#
|
|
||||||
|
|
||||||
if self.pretraining_tp > 1:
|
if self.pretraining_tp > 1:
|
||||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||||
o_proj_slices = self.o_proj.weight.split(
|
o_proj_slices = self.o_proj.weight.split(
|
||||||
@@ -152,4 +182,103 @@ def xformers_forward(
|
|||||||
else:
|
else:
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def sdp_attention_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = (
|
||||||
|
self.q_proj(hidden_states)
|
||||||
|
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
)
|
||||||
|
key_states = (
|
||||||
|
self.k_proj(hidden_states)
|
||||||
|
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
)
|
||||||
|
value_states = (
|
||||||
|
self.v_proj(hidden_states)
|
||||||
|
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
# We only apply sdp attention if we don't need to output the whole attention matrix
|
||||||
|
if not output_attentions:
|
||||||
|
with torch.backends.cuda.sdp_kernel():
|
||||||
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
is_causal=False,
|
||||||
|
)
|
||||||
|
attn_weights = None
|
||||||
|
else:
|
||||||
|
attn_weights = torch.matmul(
|
||||||
|
query_states, key_states.transpose(2, 3)
|
||||||
|
) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(
|
||||||
|
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||||
|
)
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(
|
||||||
|
attn_weights, dim=-1, dtype=torch.float32
|
||||||
|
).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|||||||
302
src/axolotl/monkeypatch/relora.py
Normal file
302
src/axolotl/monkeypatch/relora.py
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
# pylint: skip-file
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os.path
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Sequence
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
import peft
|
||||||
|
import safetensors.torch as st
|
||||||
|
import torch
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
from torch.optim.optimizer import Optimizer
|
||||||
|
from transformers import (
|
||||||
|
TrainerCallback,
|
||||||
|
TrainerControl,
|
||||||
|
TrainerState,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.relora")
|
||||||
|
|
||||||
|
|
||||||
|
def reset_optimizer(optimizer: torch.optim.Optimizer):
|
||||||
|
for group in optimizer.param_groups:
|
||||||
|
for param in group["params"]:
|
||||||
|
param_state = optimizer.state[param]
|
||||||
|
for key in param_state:
|
||||||
|
if "qmap" in key:
|
||||||
|
continue
|
||||||
|
elif key == "step" and isinstance(param_state[key], int):
|
||||||
|
param_state[key] = 0
|
||||||
|
else:
|
||||||
|
param_state[key] = torch.zeros_like(param_state[key])
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRACallback(TrainerCallback):
|
||||||
|
def __init__(self, cfg: DictDefault):
|
||||||
|
self.relora_steps = cfg.relora_steps
|
||||||
|
self.cpu_offload = cfg.relora_cpu_offload
|
||||||
|
self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
|
||||||
|
self.last_full_model = cfg.base_model
|
||||||
|
|
||||||
|
assert os.path.exists(
|
||||||
|
self.last_full_model
|
||||||
|
), "for ReLORA base_model must be a local path"
|
||||||
|
|
||||||
|
self.num_lora_restarts = 0
|
||||||
|
self.need_full_save = False
|
||||||
|
|
||||||
|
def on_step_begin(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
model: peft.LoraModel,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
**_kwargs,
|
||||||
|
):
|
||||||
|
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||||
|
checkpoint_folder = os.path.join(
|
||||||
|
args.output_dir,
|
||||||
|
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
merge_and_save(
|
||||||
|
model,
|
||||||
|
self.last_full_model,
|
||||||
|
checkpoint_folder,
|
||||||
|
reinit=True,
|
||||||
|
quantized=self.quantised,
|
||||||
|
)
|
||||||
|
reset_optimizer(optimizer)
|
||||||
|
|
||||||
|
if self.quantised:
|
||||||
|
self.last_full_model = checkpoint_folder
|
||||||
|
self.num_lora_restarts += 1
|
||||||
|
|
||||||
|
return control
|
||||||
|
|
||||||
|
def on_save(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
model: peft.LoraModel,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
checkpoint_folder = os.path.join(
|
||||||
|
args.output_dir,
|
||||||
|
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
state.global_step >= self.relora_steps
|
||||||
|
and state.global_step % self.relora_steps != 0
|
||||||
|
):
|
||||||
|
if self.quantised and self.last_full_model != checkpoint_folder:
|
||||||
|
# ensure the latest full parameter save is in the latest checkpoint
|
||||||
|
# folder, so that automatic pruning of checkpoints does not remove it
|
||||||
|
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
|
||||||
|
chunks = glob.glob(
|
||||||
|
f"{self.last_full_model}/model*.safetensors"
|
||||||
|
) + glob.glob(f"{self.last_full_model}/model*.index.json")
|
||||||
|
for path in chunks:
|
||||||
|
shutil.move(path, checkpoint_folder)
|
||||||
|
self.last_full_model = checkpoint_folder
|
||||||
|
else:
|
||||||
|
model.model.save_pretrained(checkpoint_folder, save_safetensors=True)
|
||||||
|
|
||||||
|
return control
|
||||||
|
|
||||||
|
def on_log(
|
||||||
|
self,
|
||||||
|
_args: TrainingArguments,
|
||||||
|
_state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
logs: Dict[str, float],
|
||||||
|
**_kwargs,
|
||||||
|
):
|
||||||
|
logs["num_lora_restarts"] = self.num_lora_restarts
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRAScheduler(LRScheduler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
inner_schedule: LRScheduler,
|
||||||
|
relora_steps: int,
|
||||||
|
warmup_steps: int,
|
||||||
|
min_lr_scale: float = 0.001,
|
||||||
|
) -> None:
|
||||||
|
self.inner_schedule = inner_schedule
|
||||||
|
self.relora_steps = relora_steps
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
|
self.min_lr_scale = min_lr_scale
|
||||||
|
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
||||||
|
|
||||||
|
def get_lr(self) -> float:
|
||||||
|
self.inner_schedule.last_epoch = self.last_epoch
|
||||||
|
|
||||||
|
original = self.inner_schedule.get_lr()
|
||||||
|
step = self.last_epoch
|
||||||
|
if step < self.relora_steps:
|
||||||
|
scale = 1
|
||||||
|
else:
|
||||||
|
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
|
||||||
|
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
||||||
|
if isinstance(original, Sequence):
|
||||||
|
return [lr * scale for lr in original]
|
||||||
|
else:
|
||||||
|
return original * scale
|
||||||
|
|
||||||
|
|
||||||
|
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
|
||||||
|
model_name = "model.safetensors"
|
||||||
|
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
|
||||||
|
str(Path(path) / f"{model_name}.index.json")
|
||||||
|
):
|
||||||
|
model_name = "pytorch_model.bin"
|
||||||
|
|
||||||
|
index_path = str(Path(path) / f"{model_name}.index.json")
|
||||||
|
if os.path.exists(index_path):
|
||||||
|
data = json.load(open(index_path, "r"))
|
||||||
|
return data["weight_map"]
|
||||||
|
return {key + ".weight": model_name for key in keys}
|
||||||
|
|
||||||
|
|
||||||
|
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer) -> torch.Tensor:
|
||||||
|
if isinstance(layer, peft.tuners.lora.Linear8bitLt) or isinstance(
|
||||||
|
layer, peft.tuners.lora.Linear4bit
|
||||||
|
):
|
||||||
|
adapter = layer.active_adapter
|
||||||
|
return (
|
||||||
|
peft.utils.transpose(
|
||||||
|
layer.lora_B[adapter].weight @ layer.lora_A[adapter].weight,
|
||||||
|
getattr(layer, "fan_in_fan_out", False),
|
||||||
|
)
|
||||||
|
* layer.scaling[adapter]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return layer.get_delta_weight()
|
||||||
|
|
||||||
|
|
||||||
|
def merge_and_save(
|
||||||
|
model: peft.LoraModel,
|
||||||
|
model_src: str,
|
||||||
|
model_dst: str,
|
||||||
|
reinit: bool = False,
|
||||||
|
quantized: bool = False,
|
||||||
|
cpu_offload: bool = False,
|
||||||
|
):
|
||||||
|
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
|
||||||
|
|
||||||
|
if not quantized:
|
||||||
|
for key in key_list:
|
||||||
|
try:
|
||||||
|
_parent, target, _target_name = peft.utils._get_submodules(
|
||||||
|
model.model, key
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(target, peft.tuners.lora.LoraLayer):
|
||||||
|
update = target.get_delta_weight(target.active_adapter).detach()
|
||||||
|
target.weight.data += update
|
||||||
|
|
||||||
|
if reinit:
|
||||||
|
for adapter_name in target.lora_A:
|
||||||
|
target.reset_lora_parameters(adapter_name)
|
||||||
|
for adapter_name in target.lora_embedding_A:
|
||||||
|
target.reset_lora_parameters(adapter_name)
|
||||||
|
return
|
||||||
|
|
||||||
|
os.makedirs(model_dst, exist_ok=True)
|
||||||
|
shard_paths = sharded_paths(model_src, key_list)
|
||||||
|
|
||||||
|
unique_shards = list(set(shard_paths.values()))
|
||||||
|
for shard_path in unique_shards:
|
||||||
|
out_tensors = {}
|
||||||
|
if shard_path.endswith(".safetensors"):
|
||||||
|
in_tensors = st.load_file(str(Path(model_src) / shard_path))
|
||||||
|
else:
|
||||||
|
in_tensors = torch.load(Path(model_src) / shard_path)
|
||||||
|
if "state_dict" in in_tensors:
|
||||||
|
in_tensors = in_tensors["state_dict"]
|
||||||
|
|
||||||
|
for key in key_list:
|
||||||
|
if (key + ".weight") not in shard_paths or shard_paths[
|
||||||
|
key + ".weight"
|
||||||
|
] != shard_path:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
_parent, target, _target_name = peft.utils._get_submodules(
|
||||||
|
model.model, key
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(target, peft.tuners.lora.LoraLayer):
|
||||||
|
orig_weight = in_tensors[key + ".weight"]
|
||||||
|
old_dev = target.weight.device
|
||||||
|
math_dev = "cpu" if cpu_offload else old_dev
|
||||||
|
|
||||||
|
update = lora_delta_weight(target).detach().to(math_dev)
|
||||||
|
new_weight = orig_weight.to(math_dev) + update
|
||||||
|
out_tensors[key + ".weight"] = new_weight
|
||||||
|
|
||||||
|
if reinit:
|
||||||
|
for adapter_name in target.lora_A:
|
||||||
|
target.reset_lora_parameters(adapter_name)
|
||||||
|
for adapter_name in target.lora_embedding_A:
|
||||||
|
target.reset_lora_parameters(adapter_name)
|
||||||
|
|
||||||
|
if isinstance(target, peft.tuners.lora.Linear4bit):
|
||||||
|
target.weight = (
|
||||||
|
bnb.nn.Params4bit(
|
||||||
|
new_weight,
|
||||||
|
requires_grad=False,
|
||||||
|
compress_statistics=target.weight.compress_statistics,
|
||||||
|
quant_type=target.weight.quant_type,
|
||||||
|
)
|
||||||
|
.cuda(None)
|
||||||
|
.to(old_dev)
|
||||||
|
)
|
||||||
|
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||||
|
target.weight = (
|
||||||
|
bnb.nn.Int8Params(new_weight, requires_grad=False)
|
||||||
|
.cuda(None)
|
||||||
|
.to(old_dev)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
target.weight.data = new_weight.to(old_dev)
|
||||||
|
|
||||||
|
for key in in_tensors:
|
||||||
|
if key not in out_tensors:
|
||||||
|
out_tensors[key] = in_tensors[key]
|
||||||
|
del in_tensors
|
||||||
|
|
||||||
|
out_shard_name = shard_path
|
||||||
|
if out_shard_name.startswith("pytorch_model"):
|
||||||
|
out_shard_name = (
|
||||||
|
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
|
||||||
|
+ ".safetensors"
|
||||||
|
)
|
||||||
|
|
||||||
|
shard_fn = str(Path(model_dst) / out_shard_name)
|
||||||
|
LOG.info(f"saving tensors to {shard_fn}")
|
||||||
|
st.save_file(out_tensors, shard_fn)
|
||||||
|
del out_tensors
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
if len(unique_shards) > 1:
|
||||||
|
with open(str(Path(model_dst, "model.safetensors.index.json")), "w") as fd:
|
||||||
|
json.dump({"metadata": {}, "weight_map": shard_paths}, fd)
|
||||||
@@ -4,23 +4,13 @@ import pynvml
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def gpu_memory_usage(device=0):
|
def gpu_memory_usage(device):
|
||||||
return torch.cuda.memory_allocated(device) / 1024.0**3
|
|
||||||
|
|
||||||
|
|
||||||
def gpu_memory_usage_all(device=0):
|
|
||||||
usage = torch.cuda.memory_allocated(device) / 1024.0**3
|
|
||||||
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
|
|
||||||
smi = gpu_memory_usage_smi(device)
|
|
||||||
return usage, reserved - usage, max(0, smi - reserved)
|
|
||||||
|
|
||||||
|
|
||||||
def gpu_memory_usage_smi(device=0):
|
|
||||||
if isinstance(device, torch.device):
|
if isinstance(device, torch.device):
|
||||||
device = device.index
|
device = device.index
|
||||||
if isinstance(device, str) and device.startswith("cuda:"):
|
if isinstance(device, str) and device.startswith("cuda:"):
|
||||||
device = int(device[5:])
|
device = int(device[5:])
|
||||||
|
|
||||||
|
# NB torch.cuda.memory_usage returns zero so we use lower level api
|
||||||
pynvml.nvmlInit()
|
pynvml.nvmlInit()
|
||||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||||
@@ -28,13 +18,6 @@ def gpu_memory_usage_smi(device=0):
|
|||||||
|
|
||||||
|
|
||||||
def log_gpu_memory_usage(log, msg, device):
|
def log_gpu_memory_usage(log, msg, device):
|
||||||
usage, cache, misc = gpu_memory_usage_all(device)
|
|
||||||
extras = []
|
|
||||||
if cache > 0:
|
|
||||||
extras.append(f"+{cache:.03f}GB cache")
|
|
||||||
if misc > 0:
|
|
||||||
extras.append(f"+{misc:.03f}GB misc")
|
|
||||||
log.info(
|
log.info(
|
||||||
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
|
f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
|
||||||
)
|
)
|
||||||
return usage, cache, misc
|
|
||||||
|
|||||||
@@ -33,7 +33,9 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
|||||||
)
|
)
|
||||||
|
|
||||||
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
||||||
kwargs["model"].save_pretrained(peft_model_path)
|
kwargs["model"].save_pretrained(
|
||||||
|
peft_model_path, save_safetensors=args.save_safetensors
|
||||||
|
)
|
||||||
|
|
||||||
return control
|
return control
|
||||||
|
|
||||||
@@ -74,10 +76,10 @@ class SaveBetterTransformerModelCallback(
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class GPUStatsCallback(
|
class PrintGPUStatsCallback(
|
||||||
TrainerCallback
|
TrainerCallback
|
||||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
): # pylint: disable=too-few-public-methods disable=unused-argument
|
||||||
"""Callback to track GPU utilization"""
|
"""Callback to print GPU utilization"""
|
||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
@@ -90,7 +92,7 @@ class GPUStatsCallback(
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if not self.logged and state.global_step > 1:
|
if not self.logged:
|
||||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
||||||
self.logged = True
|
self.logged = True
|
||||||
return control
|
return control
|
||||||
|
|||||||
@@ -10,6 +10,3 @@ class DictDefault(Dict):
|
|||||||
|
|
||||||
def __missing__(self, key):
|
def __missing__(self, key):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __or__(self, other):
|
|
||||||
return DictDefault(super().__or__(other))
|
|
||||||
|
|||||||
@@ -32,27 +32,37 @@ if TYPE_CHECKING:
|
|||||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(
|
||||||
|
tokenizer_config,
|
||||||
|
tokenizer_type,
|
||||||
|
cfg,
|
||||||
|
):
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
use_fast = True # this is the default
|
use_fast = True # this is the default
|
||||||
|
|
||||||
if cfg.tokenizer_use_fast is not None:
|
if cfg.tokenizer_use_fast is not None:
|
||||||
use_fast = cfg.tokenizer_use_fast
|
use_fast = cfg.tokenizer_use_fast
|
||||||
if cfg.tokenizer_legacy is not None:
|
if cfg.tokenizer_legacy is not None:
|
||||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||||
|
if tokenizer_type:
|
||||||
|
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||||
|
tokenizer_config,
|
||||||
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
use_fast=use_fast,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
tokenizer_config,
|
||||||
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
use_fast=use_fast,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
tokenizer_cls = AutoTokenizer
|
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||||
if cfg.tokenizer_type:
|
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
|
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
|
||||||
tokenizer_config,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
|
||||||
use_fast=use_fast,
|
|
||||||
**tokenizer_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ in [
|
if tokenizer.__class__.__name__ in [
|
||||||
"LlamaTokenizer",
|
"LlamaTokenizer",
|
||||||
@@ -60,11 +70,6 @@ def load_tokenizer(cfg):
|
|||||||
]:
|
]:
|
||||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
|
||||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
|
||||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
|
||||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
|
||||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -87,6 +92,7 @@ def load_model(
|
|||||||
base_model = cfg.base_model
|
base_model = cfg.base_model
|
||||||
base_model_config = cfg.base_model_config
|
base_model_config = cfg.base_model_config
|
||||||
model_type = cfg.model_type
|
model_type = cfg.model_type
|
||||||
|
adapter = cfg.adapter
|
||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
@@ -112,7 +118,9 @@ def load_model(
|
|||||||
LOG.info("patching with xformers attention")
|
LOG.info("patching with xformers attention")
|
||||||
hijack_llama_attention()
|
hijack_llama_attention()
|
||||||
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
|
hijack_llama_sdp_attention,
|
||||||
|
)
|
||||||
|
|
||||||
LOG.info("patching with sdp attention")
|
LOG.info("patching with sdp attention")
|
||||||
hijack_llama_sdp_attention()
|
hijack_llama_sdp_attention()
|
||||||
@@ -233,7 +241,6 @@ def load_model(
|
|||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -268,7 +275,6 @@ def load_model(
|
|||||||
elif model_type and not cfg.trust_remote_code:
|
elif model_type and not cfg.trust_remote_code:
|
||||||
model = getattr(transformers, model_type).from_pretrained(
|
model = getattr(transformers, model_type).from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -299,7 +305,6 @@ def load_model(
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -313,7 +318,6 @@ def load_model(
|
|||||||
LOG.exception(err)
|
LOG.exception(err)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -360,7 +364,7 @@ def load_model(
|
|||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
module.to(torch_dtype)
|
module.to(torch_dtype)
|
||||||
|
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, adapter)
|
||||||
|
|
||||||
if cfg.ddp and not load_in_8bit:
|
if cfg.ddp and not load_in_8bit:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
@@ -377,6 +381,9 @@ def load_model(
|
|||||||
module.scales = module.scales.half()
|
module.scales = module.scales.half()
|
||||||
module.bias = module.bias.half()
|
module.bias = module.bias.half()
|
||||||
|
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
torch.cuda.device_count() > 1
|
torch.cuda.device_count() > 1
|
||||||
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
||||||
@@ -399,9 +406,6 @@ def load_model(
|
|||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.transform(model)
|
model = BetterTransformer.transform(model)
|
||||||
|
|
||||||
if cfg.adapter is not None:
|
|
||||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|||||||
@@ -21,8 +21,9 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
|||||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
GPUStatsCallback,
|
PrintGPUStatsCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
SavePeftModelCallback,
|
||||||
)
|
)
|
||||||
@@ -555,7 +556,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
||||||
|
|
||||||
callbacks = []
|
callbacks = []
|
||||||
callbacks.append(GPUStatsCallback(cfg))
|
callbacks.append(PrintGPUStatsCallback(cfg))
|
||||||
|
|
||||||
|
if cfg.relora_steps:
|
||||||
|
relora_steps = int(cfg.relora_steps)
|
||||||
|
relora_warmup_steps = int(cfg.relora_warmup_steps)
|
||||||
|
callbacks.append(ReLoRACallback(cfg))
|
||||||
|
|
||||||
|
(optimizer, lr_scheduler) = trainer_kwargs["optimizers"]
|
||||||
|
trainer_kwargs["optimizers"] = (
|
||||||
|
optimizer,
|
||||||
|
ReLoRAScheduler(optimizer, lr_scheduler, relora_steps, relora_warmup_steps),
|
||||||
|
)
|
||||||
|
|
||||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||||
if cfg.early_stopping_patience:
|
if cfg.early_stopping_patience:
|
||||||
early_stop_cb = EarlyStoppingCallback(
|
early_stop_cb = EarlyStoppingCallback(
|
||||||
|
|||||||
@@ -1,70 +1,12 @@
|
|||||||
"""Module for working with config dicts"""
|
"""Module for validating config files"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def choose_device(cfg):
|
|
||||||
def get_device():
|
|
||||||
try:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return f"cuda:{cfg.local_rank}"
|
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
return "mps"
|
|
||||||
|
|
||||||
raise SystemError("No CUDA/mps device found")
|
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
|
||||||
return "cpu"
|
|
||||||
|
|
||||||
cfg.device = get_device()
|
|
||||||
if cfg.device_map != "auto":
|
|
||||||
if cfg.device.startswith("cuda"):
|
|
||||||
cfg.device_map = {"": cfg.local_rank}
|
|
||||||
else:
|
|
||||||
cfg.device_map = {"": cfg.device}
|
|
||||||
|
|
||||||
# in `accelerate launch`, we need to not pass through any device map and let
|
|
||||||
# accelerate figure out which parts of the model to put on which gpu
|
|
||||||
accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
|
|
||||||
if accelerate_vars:
|
|
||||||
cfg.device_map = None
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_config(cfg):
|
|
||||||
# setup some derived config / hyperparams
|
|
||||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
|
||||||
cfg.batch_size // cfg.micro_batch_size
|
|
||||||
)
|
|
||||||
cfg.batch_size = (
|
|
||||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
|
||||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
||||||
choose_device(cfg)
|
|
||||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
|
||||||
if cfg.ddp:
|
|
||||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
|
||||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
|
||||||
|
|
||||||
if cfg.device == "mps":
|
|
||||||
cfg.load_in_8bit = False
|
|
||||||
cfg.tf32 = False
|
|
||||||
if cfg.bf16:
|
|
||||||
cfg.fp16 = True
|
|
||||||
cfg.bf16 = False
|
|
||||||
else:
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
|
||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -119,6 +61,9 @@ def validate_config(cfg):
|
|||||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||||
|
|
||||||
|
if cfg.relora_steps and cfg.adapter not in ("lora", "qlora"):
|
||||||
|
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||||
|
|
||||||
if cfg.trust_remote_code:
|
if cfg.trust_remote_code:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||||
@@ -72,13 +72,6 @@ class DictDefaultTest(unittest.TestCase):
|
|||||||
|
|
||||||
assert cfg.random_key is None, "DictDefault should return None for missing keys"
|
assert cfg.random_key is None, "DictDefault should return None for missing keys"
|
||||||
|
|
||||||
def test_dict_or(self):
|
|
||||||
cfg = DictDefault({}) | DictDefault({})
|
|
||||||
|
|
||||||
assert (
|
|
||||||
cfg.random_key is None
|
|
||||||
), "DictDefault should return None for missing keys after | operation"
|
|
||||||
|
|
||||||
def test_dict_nested_missingparentkey(self):
|
def test_dict_nested_missingparentkey(self):
|
||||||
"""
|
"""
|
||||||
Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.
|
Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.
|
||||||
|
|||||||
@@ -13,22 +13,17 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def test_default_use_fast(self):
|
def test_default_use_fast(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault({})
|
||||||
{
|
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
|
||||||
assert "Fast" in tokenizer.__class__.__name__
|
assert "Fast" in tokenizer.__class__.__name__
|
||||||
|
|
||||||
def test_dont_use_fast(self):
|
def test_dont_use_fast(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
|
||||||
"tokenizer_use_fast": False,
|
"tokenizer_use_fast": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
||||||
assert "Fast" not in tokenizer.__class__.__name__
|
assert "Fast" not in tokenizer.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from typing import Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.validation import validate_config
|
||||||
|
|
||||||
|
|
||||||
class ValidationTest(unittest.TestCase):
|
class ValidationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user