Compare commits

...

21 Commits

Author SHA1 Message Date
Wing Lian
31079cd5fd smart resize embeddings
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-14 23:44:15 -04:00
NanoCode012
41ecb451c2 Feat(doc): Add max_steps to readme (#389) 2023-08-15 00:34:22 +09:00
Gabriel Puliatti
3c2ad00d07 Feat(config): add max steps (#387) 2023-08-14 11:19:29 -04:00
florian peyron
5d48a10548 Added "epoch" evaluation_strategy (#388) 2023-08-14 10:59:23 -04:00
NanoCode012
73a0b6ead5 Feat(config): Add hub_strategy (#386) 2023-08-14 07:12:55 -04:00
florian peyron
63fdb5a7fb Error msg for sharegpt if conv has less than 2 msg (#379) 2023-08-14 17:40:40 +09:00
mhenrichsen
fdffef5940 new llama-2 default settings (#370)
* new default settings

* fix whitespace

* rm max packed sequence length

---------

Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
2023-08-14 17:39:09 +09:00
Wing Lian
919246fbc1 don't pass rope_scaling kwarg if it's None (#383) 2023-08-13 18:57:38 -04:00
Wing Lian
ffac902c1b bump flash-attn to 2.0.4 for the base docker image (#382) 2023-08-13 17:55:04 -04:00
Charles Goddard
15f6e57eaa Fix crash when running without CUDA 2023-08-13 13:36:40 -07:00
NanoCode012
729c299256 Feat(doc): Improve sharegpt doc (#378)
* Feat(doc): Improve sharegpt doc

* Fix typo
2023-08-14 00:36:00 +09:00
Wing Lian
86a91e260b save tokenizer before training starts (#380) 2023-08-13 11:28:58 -04:00
Aman Gupta Karmani
094fc2c6e6 try to detect accelerate and only use device_map=None in that case (#373) 2023-08-13 00:32:07 -04:00
Wing Lian
2dafa730ef Create FUNDING.yml 2023-08-13 00:30:34 -04:00
Wing Lian
343ac84e5a fix check for flash attn branching (#377) 2023-08-12 22:48:08 -04:00
Aman Karmani
0c967279ce remove unnecessary local variable 2023-08-13 01:58:39 +00:00
Aman Karmani
efb3b2c95e simplify load_tokenizer 2023-08-12 18:55:06 -07:00
Aman Karmani
7b55fe6419 improve GPU logging to break out pytorch cache and system mem 2023-08-12 18:52:57 -07:00
Aman Karmani
e029ab34ea quiet noise from llama tokenizer by setting pad token earlier 2023-08-12 18:31:40 -07:00
Aman Karmani
8cec513447 extract module for working with cfg 2023-08-12 18:25:27 -07:00
Aman Karmani
a13e45d548 fix DefaultDict.__or__ 2023-08-13 01:15:50 +00:00
17 changed files with 244 additions and 117 deletions

13
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1,13 @@
# These are supported funding model platforms
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -136,7 +136,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt:chat`: conversations
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -225,6 +225,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
```json
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
@@ -322,9 +326,9 @@ tokenizer_type: AutoTokenizer
trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast:
# resize the model embeddings when new tokens are added to multiples of 32
# this is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# resize the model embeddings when new tokens are added to multiples of N
# multiples of 32 are reported to improve training speed on some models
resize_token_embeddings_multiple:
# whether you are training a 4-bit GPTQ quantized model
gptq: true
@@ -360,6 +364,9 @@ dataset_prepared_path: data/last_run_prepared
push_dataset_to_hub: # repo path
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean
@@ -428,7 +435,8 @@ learning_rate: 0.00003
logging_steps:
save_steps:
eval_steps:
save_total_limit:
save_total_limit: # checkpoints saved at a time
max_steps:
# save model as safetensors (require safetensors package)
save_safetensors:

View File

@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
cd flash-attention && \
git checkout v2.0.1 && \
git checkout v2.0.4 && \
python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \

View File

@@ -15,7 +15,7 @@ val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 4096
max_packed_sequence_len: 4096
sample_packing: true
adapter: lora
lora_model_dir:
@@ -49,8 +49,8 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
@@ -64,4 +64,3 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -18,7 +18,8 @@ adapter: qlora
lora_model_dir:
sequence_len: 4096
max_packed_sequence_len: 4096
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
@@ -50,8 +51,8 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
@@ -65,4 +66,3 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -18,7 +18,7 @@ from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.logging_config import configure_logging
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process
@@ -29,7 +29,6 @@ from axolotl.utils.trainer import (
process_datasets_for_packing,
setup_trainer,
)
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -44,27 +43,6 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to finish): ")
instruction = ""
@@ -194,36 +172,13 @@ def train(
validate_config(cfg)
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
normalize_config(cfg)
setup_wandb_env_vars(cfg)
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
if cfg.tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
LOG.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
@@ -254,7 +209,13 @@ def train(
cfg, train_dataset, eval_dataset
)
barrier()
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...")
@@ -269,8 +230,6 @@ def train(
LOG.info("Finished preparing dataset. Exiting...")
return
log_gpu_memory_usage(LOG, "baseline", cfg.device)
# Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer)
@@ -354,6 +313,7 @@ def train(
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True

View File

@@ -92,7 +92,7 @@ def forward(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif position_ids.shape[0] == 1:
elif attention_mask.shape[0] == 1:
# special handling using sample packing
qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)

View File

@@ -312,7 +312,9 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
raise IndexError
raise IndexError(
f"A conversation entry has less than 2 messages :\n{source}"
)
conv = self._conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

View File

@@ -4,13 +4,23 @@ import pynvml
import torch
def gpu_memory_usage(device):
def gpu_memory_usage(device=0):
return torch.cuda.memory_allocated(device) / 1024.0**3
def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
smi = gpu_memory_usage_smi(device)
return usage, reserved - usage, max(0, smi - reserved)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])
# NB torch.cuda.memory_usage returns zero so we use lower level api
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
@@ -18,6 +28,16 @@ def gpu_memory_usage(device):
def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available():
return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:
extras.append(f"+{cache:.03f}GB cache")
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
)
return usage, cache, misc

View File

@@ -74,10 +74,10 @@ class SaveBetterTransformerModelCallback(
return control
class PrintGPUStatsCallback(
class GPUStatsCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods disable=unused-argument
"""Callback to print GPU utilization"""
"""Callback to track GPU utilization"""
def __init__(self, cfg):
self.cfg = cfg
@@ -90,7 +90,7 @@ class PrintGPUStatsCallback(
control: TrainerControl,
**kwargs,
):
if not self.logged:
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
return control

View File

@@ -1,12 +1,70 @@
"""Module for validating config files"""
"""Module for working with config dicts"""
import logging
import os
import torch
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl")
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
# in `accelerate launch`, we need to not pass through any device map and let
# accelerate figure out which parts of the model to put on which gpu
accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
if accelerate_vars:
cfg.device_map = None
def normalize_config(cfg):
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
log_gpu_memory_usage(LOG, "baseline", cfg.device)
def validate_config(cfg):
if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError(

View File

@@ -10,3 +10,6 @@ class DictDefault(Dict):
def __missing__(self, key):
return None
def __or__(self, other):
return DictDefault(super().__or__(other))

View File

@@ -32,37 +32,66 @@ if TYPE_CHECKING:
from axolotl.utils.dict import DictDefault # noqa: F401
def load_tokenizer(
tokenizer_config,
tokenizer_type,
cfg,
def smart_tokenizer_and_embedding_resize(
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
resize_token_embeddings_multiple: Optional[int] = None,
):
"""Resize tokenizer and embedding.
Note: This function resizes the tokenizer to accommodate additional special tokens and the
embedding matrix of the model to match the new size of the tokenizer. If any new special tokens
have been added, the function computes the average embedding values of the existing embeddings
and sets those values for the new special token embeddings. This is done separately for the input
embeddings and output embeddings of the model.
"""
old_tokens = model.get_input_embeddings().weight.data.shape[0]
num_new_tokens = len(tokenizer) - old_tokens
embeddings_len = (
math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
* resize_token_embeddings_multiple
if resize_token_embeddings_multiple
else len(tokenizer)
)
model.resize_token_embeddings(embeddings_len)
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def load_tokenizer(cfg):
tokenizer_kwargs = {}
use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
if tokenizer.__class__.__name__ in [
"LlamaTokenizer",
@@ -70,6 +99,11 @@ def load_tokenizer(
]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -92,7 +126,6 @@ def load_model(
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
adapter = cfg.adapter
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
@@ -235,12 +268,17 @@ def load_model(
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM
config_kwargs = {}
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config, rope_scaling=cfg.rope_scaling
base_model_config,
**config_kwargs,
)
model = LlamaForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
@@ -275,6 +313,7 @@ def load_model(
elif model_type and not cfg.trust_remote_code:
model = getattr(transformers, model_type).from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
@@ -305,6 +344,7 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
@@ -318,6 +358,7 @@ def load_model(
LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
@@ -325,17 +366,16 @@ def load_model(
**model_kwargs,
)
embeddings_len = (
math.ceil(len(tokenizer) / 32) * 32
if cfg.resize_token_embeddings_to_32x
else len(tokenizer)
smart_tokenizer_and_embedding_resize(
tokenizer,
model,
resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
)
model.resize_token_embeddings(embeddings_len)
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings
):
LOG.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
@@ -364,7 +404,7 @@ def load_model(
if hasattr(module, "weight"):
module.to(torch_dtype)
model, lora_config = load_adapter(model, cfg, adapter)
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")
@@ -381,9 +421,6 @@ def load_model(
module.scales = module.scales.half()
module.bias = module.bias.half()
if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after adapters", model.device)
if (
torch.cuda.device_count() > 1
and int(os.getenv("WORLD_SIZE", "1")) > 1
@@ -406,6 +443,9 @@ def load_model(
if cfg.flash_optimum:
model = BetterTransformer.transform(model)
if cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling
return model, lora_config

View File

@@ -22,7 +22,7 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import (
PrintGPUStatsCallback,
GPUStatsCallback,
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
@@ -440,6 +440,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
@@ -448,8 +451,17 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
if cfg.val_set_size == 0:
evaluation_strategy = "no"
elif cfg.eval_steps < 1:
# eval every epoch
evaluation_strategy = "epoch"
else:
# eval every eval_steps steps
evaluation_strategy = "steps"
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
@@ -459,7 +471,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
evaluation_strategy=evaluation_strategy,
save_strategy="steps" if cfg.save_steps else "epoch",
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=cfg.save_steps,
@@ -555,7 +567,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
callbacks = []
callbacks.append(PrintGPUStatsCallback(cfg))
callbacks.append(GPUStatsCallback(cfg))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(

View File

@@ -72,6 +72,13 @@ class DictDefaultTest(unittest.TestCase):
assert cfg.random_key is None, "DictDefault should return None for missing keys"
def test_dict_or(self):
cfg = DictDefault({}) | DictDefault({})
assert (
cfg.random_key is None
), "DictDefault should return None for missing keys after | operation"
def test_dict_nested_missingparentkey(self):
"""
Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.

View File

@@ -13,17 +13,22 @@ class TestTokenizers(unittest.TestCase):
"""
def test_default_use_fast(self):
cfg = DictDefault({})
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
}
)
tokenizer = load_tokenizer(cfg)
assert "Fast" in tokenizer.__class__.__name__
def test_dont_use_fast(self):
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"tokenizer_use_fast": False,
}
)
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
tokenizer = load_tokenizer(cfg)
assert "Fast" not in tokenizer.__class__.__name__

View File

@@ -6,8 +6,8 @@ from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config
class ValidationTest(unittest.TestCase):