Compare commits

...

10 Commits

Author SHA1 Message Date
Wing Lian
31079cd5fd smart resize embeddings 2023-08-14 23:44:15 -04:00
NanoCode012
41ecb451c2 Feat(doc): Add max_steps to readme (#389) 2023-08-15 00:34:22 +09:00
Gabriel Puliatti
3c2ad00d07 Feat(config): add max steps (#387) 2023-08-14 11:19:29 -04:00
florian peyron
5d48a10548 Added "epoch" evaluation_strategy (#388) 2023-08-14 10:59:23 -04:00
NanoCode012
73a0b6ead5 Feat(config): Add hub_strategy (#386) 2023-08-14 07:12:55 -04:00
florian peyron
63fdb5a7fb Error msg for sharegpt if conv has less than 2 msg (#379) 2023-08-14 17:40:40 +09:00
mhenrichsen
fdffef5940 new llama-2 default settings (#370)
* new default settings

* fix whitespace

* rm max packed sequence length

---------

Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
2023-08-14 17:39:09 +09:00
Wing Lian
919246fbc1 don't pass rope_scaling kwarg if it's None (#383) 2023-08-13 18:57:38 -04:00
Wing Lian
ffac902c1b bump flash-attn to 2.0.4 for the base docker image (#382) 2023-08-13 17:55:04 -04:00
Charles Goddard
15f6e57eaa Fix crash when running without CUDA 2023-08-13 13:36:40 -07:00
9 changed files with 92 additions and 24 deletions

View File

@@ -326,9 +326,9 @@ tokenizer_type: AutoTokenizer
trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast:
-# resize the model embeddings when new tokens are added to multiples of 32
-# this is reported to improve training speed on some models
-resize_token_embeddings_to_32x:
+# resize the model embeddings when new tokens are added to multiples of N
+# multiples of 32 are reported to improve training speed on some models
+resize_token_embeddings_multiple:
# whether you are training a 4-bit GPTQ quantized model
gptq: true
@@ -364,6 +364,9 @@ dataset_prepared_path: data/last_run_prepared
push_dataset_to_hub: # repo path
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
+# how to push checkpoints to hub
+# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
+hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean
@@ -432,7 +435,8 @@ learning_rate: 0.00003
logging_steps:
save_steps:
eval_steps:
-save_total_limit:
+save_total_limit: # checkpoints saved at a time
+max_steps:
# save model as safetensors (require safetensors package)
save_safetensors:
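
For orientation (not part of the diff): the hub_strategy, save_total_limit, and max_steps options above correspond to transformers.TrainingArguments parameters of the same name; a minimal sketch, with placeholder repo/path values chosen only for illustration:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",                   # placeholder path
    push_to_hub=True,
    hub_model_id="your-user/your-model",  # hub_model_id in the config (placeholder)
    hub_strategy="every_save",            # hub_strategy (see the linked HF docs)
    save_total_limit=2,                   # keep at most 2 checkpoints on disk
    max_steps=1000,                       # hard cap on optimizer steps
)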

View File

@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
cd flash-attention && \
-git checkout v2.0.1 && \
+git checkout v2.0.4 && \
python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \
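
As a quick sanity check (not part of the diff), the installed flash-attn version can be inspected from Python once the image builds; a minimal sketch:

import flash_attn

# expect "2.0.4" after the bump above; older images report "2.0.1"
print(flash_attn.__version__)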

View File

@@ -15,7 +15,7 @@ val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 4096
-max_packed_sequence_len: 4096
+sample_packing: true
adapter: lora
lora_model_dir:
@@ -49,8 +49,8 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
-xformers_attention: true
-flash_attention:
+xformers_attention:
+flash_attention: true
warmup_steps: 10
eval_steps: 20
@@ -64,4 +64,3 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -18,7 +18,8 @@ adapter: qlora
lora_model_dir:
sequence_len: 4096
-max_packed_sequence_len: 4096
+sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
@@ -50,8 +51,8 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
-xformers_attention: true
-flash_attention:
+xformers_attention:
+flash_attention: true
warmup_steps: 10
eval_steps: 20
@@ -65,4 +66,3 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
-pad_token: "<pad>"

View File

@@ -209,7 +209,13 @@ def train(
cfg, train_dataset, eval_dataset
)
barrier()
-total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+if cfg.max_steps:
+    total_num_steps = min(
+        calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
+    )
+    LOG.info(f"Maximum number of steps set at {total_num_steps}")
+else:
+    total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...")

View File

@@ -312,7 +312,9 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
-raise IndexError
+raise IndexError(
+    f"A conversation entry has less than 2 messages :\n{source}"
+)
conv = self._conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

View File

@@ -28,6 +28,9 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device):
+    if not torch.cuda.is_available():
+        return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:
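
With the guard in place, calling the helper on a machine without a visible GPU returns zeros instead of failing inside the CUDA memory queries; a hedged usage sketch (the logger here is illustrative, not the project's):

import logging
LOG = logging.getLogger(__name__)  # illustrative logger

# on a host without CUDA, the new guard makes this return (0, 0, 0)
usage, cache, misc = log_gpu_memory_usage(LOG, "baseline", 0)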

View File

@@ -32,6 +32,45 @@ if TYPE_CHECKING:
from axolotl.utils.dict import DictDefault # noqa: F401
+def smart_tokenizer_and_embedding_resize(
+    tokenizer: transformers.PreTrainedTokenizer,
+    model: transformers.PreTrainedModel,
+    resize_token_embeddings_multiple: Optional[int] = None,
+):
+    """Resize tokenizer and embedding.
+
+    Note: This function resizes the tokenizer to accommodate additional special tokens and the
+    embedding matrix of the model to match the new size of the tokenizer. If any new special tokens
+    have been added, the function computes the average embedding values of the existing embeddings
+    and sets those values for the new special token embeddings. This is done separately for the input
+    embeddings and output embeddings of the model.
+    """
+    old_tokens = model.get_input_embeddings().weight.data.shape[0]
+    num_new_tokens = len(tokenizer) - old_tokens
+
+    embeddings_len = (
+        math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
+        * resize_token_embeddings_multiple
+        if resize_token_embeddings_multiple
+        else len(tokenizer)
+    )
+    model.resize_token_embeddings(embeddings_len)
+
+    if num_new_tokens > 0:
+        input_embeddings = model.get_input_embeddings().weight.data
+        output_embeddings = model.get_output_embeddings().weight.data
+
+        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True
+        )
+        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True
+        )
+
+        input_embeddings[-num_new_tokens:] = input_embeddings_avg
+        output_embeddings[-num_new_tokens:] = output_embeddings_avg
def load_tokenizer(cfg):
tokenizer_kwargs = {}
use_fast = True # this is the default
@@ -229,8 +268,12 @@ def load_model(
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM
+config_kwargs = {}
+if cfg.rope_scaling:
+    config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
-    base_model_config, rope_scaling=cfg.rope_scaling
+    base_model_config,
+    **config_kwargs,
)
model = LlamaForCausalLM.from_pretrained(
base_model,
@@ -323,17 +366,16 @@ def load_model(
**model_kwargs,
)
-embeddings_len = (
-    math.ceil(len(tokenizer) / 32) * 32
-    if cfg.resize_token_embeddings_to_32x
-    else len(tokenizer)
+smart_tokenizer_and_embedding_resize(
+    tokenizer,
+    model,
+    resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
)
-model.resize_token_embeddings(embeddings_len)
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
-and cfg.sequence_len >= model.config.max_position_embeddings
+and cfg.sequence_len > model.config.max_position_embeddings
):
LOG.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"

View File

@@ -440,6 +440,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
+if cfg.hub_strategy:
+    training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
@@ -448,8 +451,17 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
+if cfg.val_set_size == 0:
+    evaluation_strategy = "no"
+elif cfg.eval_steps < 1:
+    # eval every epoch
+    evaluation_strategy = "epoch"
+else:
+    # eval every eval_steps steps
+    evaluation_strategy = "steps"
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
-    # max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
+    max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
@@ -459,7 +471,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
evaluation_strategy=evaluation_strategy,
save_strategy="steps" if cfg.save_steps else "epoch",
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=cfg.save_steps,
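
Putting the branch above into words: no eval split means no evaluation; an eval_steps value below 1 selects per-epoch evaluation; otherwise evaluation runs every eval_steps steps. A standalone sketch of the same selection (the helper name is made up here, and the None-handling for eval_steps is an added assumption, not in the diff):

def pick_evaluation_strategy(val_set_size, eval_steps):
    """Mirror the evaluation_strategy branch added in setup_trainer (illustrative)."""
    if val_set_size == 0:
        return "no"        # nothing to evaluate on
    if eval_steps is not None and eval_steps < 1:
        return "epoch"     # evaluate once per epoch
    return "steps"         # evaluate every eval_steps steps

assert pick_evaluation_strategy(0, 20) == "no"
assert pick_evaluation_strategy(0.01, 0.5) == "epoch"
assert pick_evaluation_strategy(0.01, 20) == "steps"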