Compare commits
4 Commits
feature/re
...
ssmi-main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9793faf6dc | ||
|
|
64852ae15a | ||
|
|
1fed74b1d9 | ||
|
|
a300a4db1d |
49
README.md
49
README.md
@@ -375,14 +375,7 @@ dataset_shard_idx:
|
||||
sequence_len: 2048
|
||||
# max sequence length to concatenate training samples together up to
|
||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# FutureWarning: This will soon be DEPRECATED
|
||||
max_packed_sequence_len: 1024
|
||||
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
sample_packing:
|
||||
# you can set these packing optimizations AFTER starting a training at least once.
|
||||
# The trainer will provide recommended values for these values.
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
|
||||
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
@@ -408,12 +401,11 @@ lora_out_dir:
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# wandb configuration if you're using it
|
||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
wandb_project: # your wandb project name
|
||||
wandb_entity: # a wandb Team name if using a Team
|
||||
wandb_mode:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id: # set the name of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
wandb_run_id:
|
||||
wandb_log_model: # 'checkpoint'
|
||||
|
||||
# where to save the finished model to
|
||||
output_dir: ./completed-model
|
||||
@@ -428,16 +420,13 @@ learning_rate: 0.00003
|
||||
logging_steps:
|
||||
save_steps:
|
||||
eval_steps:
|
||||
save_total_limit:
|
||||
|
||||
# save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# group similarly sized data to minimize padding
|
||||
# may be slower to start, as it must download and sort the entire dataset
|
||||
# note that training loss may have an oscillating pattern with this enabled
|
||||
# don't use this, leads to wonky training (according to someone on the internet)
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
@@ -483,10 +472,6 @@ landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# llama only
|
||||
xpos_rope:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
@@ -518,9 +503,6 @@ torchdistx_path:
|
||||
# Set padding for data collator to 'longest'
|
||||
collator_pad_to_longest:
|
||||
|
||||
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
|
||||
pretraining_dataset:
|
||||
|
||||
# Debug mode
|
||||
debug:
|
||||
|
||||
@@ -540,14 +522,7 @@ Run
|
||||
accelerate launch scripts/finetune.py configs/your_config.yml
|
||||
```
|
||||
|
||||
#### Multi-GPU
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
|
||||
```
|
||||
|
||||
##### Config
|
||||
#### Multi-GPU Config
|
||||
|
||||
- llama FSDP
|
||||
```yaml
|
||||
@@ -562,18 +537,6 @@ fsdp_config:
|
||||
|
||||
- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
|
||||
|
||||
##### Weights & Biases Logging
|
||||
|
||||
- wandb options
|
||||
```yaml
|
||||
wandb_mode:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Pass the appropriate flag to the train command:
|
||||
|
||||
@@ -37,18 +37,18 @@
|
||||
"lr": "auto",
|
||||
"betas": [
|
||||
0.9,
|
||||
0.95
|
||||
0.999
|
||||
],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"type": "OneCycle",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto"
|
||||
"cycle_min_lr": 0.00001,
|
||||
"cycle_max_lr": 0.00003,
|
||||
"cycle_first_step_size": 120
|
||||
}
|
||||
},
|
||||
"train_batch_size": "auto",
|
||||
@@ -23,7 +23,6 @@ lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -36,7 +35,7 @@ torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
@@ -24,7 +24,6 @@ lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -38,7 +38,6 @@ lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -24,7 +24,6 @@ lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -20,7 +20,6 @@ lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -33,7 +32,7 @@ torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
@@ -22,7 +22,6 @@ lora_target_modules:
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: llama-7b-lora-int4
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -18,7 +18,6 @@ lora_dropout:
|
||||
lora_target_modules:
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -26,7 +26,6 @@ lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -39,7 +38,7 @@ lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
@@ -27,7 +27,6 @@ lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -40,7 +39,7 @@ lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
@@ -20,7 +20,6 @@ lora_target_modules:
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: mpt-alpaca-7b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -22,7 +22,6 @@ lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -28,7 +28,6 @@ lora_target_modules:
|
||||
- o_proj
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -22,7 +22,6 @@ lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -35,7 +34,7 @@ torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
@@ -23,7 +23,6 @@ lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -17,7 +17,6 @@ lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -21,7 +21,6 @@ lora_target_modules:
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: redpajama-alpaca-3b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -20,7 +20,6 @@ lora_target_modules:
|
||||
- mlp_down
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project: lora-replit
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -37,7 +37,6 @@ lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git
|
||||
bitsandbytes>=0.41.1
|
||||
bitsandbytes>=0.39.0
|
||||
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
|
||||
addict
|
||||
fire
|
||||
@@ -13,12 +13,9 @@ einops
|
||||
xformers
|
||||
optimum
|
||||
hf_transfer
|
||||
numba
|
||||
numpy==1.24.4
|
||||
# qlora things
|
||||
bert-score==0.3.13
|
||||
evaluate==0.4.0
|
||||
rouge-score==0.1.2
|
||||
scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
|
||||
@@ -18,17 +18,11 @@ from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import barrier, is_main_process
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import (
|
||||
calculate_total_num_steps,
|
||||
process_datasets_for_packing,
|
||||
setup_trainer,
|
||||
)
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
from axolotl.utils.validation import validate_config
|
||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||
|
||||
@@ -237,25 +231,12 @@ def train(
|
||||
cfg.pretraining_dataset,
|
||||
tokenizer,
|
||||
max_tokens=cfg.sequence_len,
|
||||
seed=cfg.seed or 42,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
eval_dataset = None
|
||||
|
||||
if is_main_process():
|
||||
# process on rank 0 first so it gets cached so other ranks load from cache
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset
|
||||
)
|
||||
barrier()
|
||||
if not is_main_process():
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset
|
||||
)
|
||||
barrier()
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||
|
||||
if cfg.debug or "debug" in kwargs:
|
||||
LOG.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
@@ -269,13 +250,16 @@ def train(
|
||||
LOG.info("Finished preparing dataset. Exiting...")
|
||||
return
|
||||
|
||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||
|
||||
# Load the model and tokenizer
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
model, peft_config = load_model(cfg, tokenizer)
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
LOG.info("loading model and peft_config...")
|
||||
model, peft_config = load_model(
|
||||
cfg.base_model,
|
||||
cfg.base_model_config,
|
||||
cfg.model_type,
|
||||
tokenizer,
|
||||
cfg,
|
||||
adapter=cfg.adapter,
|
||||
)
|
||||
|
||||
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
@@ -284,11 +268,7 @@ def train(
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info("saving merged model")
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
return
|
||||
|
||||
if cfg.inference:
|
||||
@@ -303,12 +283,10 @@ def train(
|
||||
return
|
||||
|
||||
if "shard" in kwargs:
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
return
|
||||
|
||||
trainer = setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||
)
|
||||
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
||||
|
||||
model.config.use_cache = False
|
||||
|
||||
@@ -327,7 +305,7 @@ def train(
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(
|
||||
@@ -367,15 +345,11 @@ def train(
|
||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.fsdp:
|
||||
trainer.save_model(cfg.output_dir)
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
if cfg.adapter == "lora" and cfg.relora_steps:
|
||||
model = model.merge_and_unload()
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
|
||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
from datasets import IterableDataset
|
||||
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class TokenizedPromptDataset(Dataset):
|
||||
class TokenizedPromptDataset(IterableDataset):
|
||||
"""
|
||||
Dataset that returns tokenized prompts from a stream of text files.
|
||||
Iterable dataset that returns tokenized prompts from a stream of text files.
|
||||
Args:
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
@@ -30,18 +30,19 @@ class TokenizedPromptDataset(Dataset):
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: IterableDataset,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
super().__init__(self.process(dataset).data, **kwargs)
|
||||
self.dataset = dataset
|
||||
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, os.cpu_count())
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
def __iter__(self):
|
||||
features = self.dataset.features.keys()
|
||||
num_proc = os.cpu_count()
|
||||
return iter(
|
||||
self.dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -76,21 +77,14 @@ class ConstantLengthDataset(IterableDataset):
|
||||
self.tokens_dtype = torch.int64
|
||||
|
||||
def __iter__(self):
|
||||
buffer = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"position_ids": [],
|
||||
}
|
||||
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
buffer_len = 0
|
||||
for dataset in self.datasets:
|
||||
idx = 0
|
||||
iterator = iter(dataset)
|
||||
more_examples = True
|
||||
while more_examples:
|
||||
try:
|
||||
example = next(iterator)
|
||||
idx += 1
|
||||
except StopIteration:
|
||||
more_examples = False
|
||||
example = None
|
||||
@@ -112,9 +106,6 @@ class ConstantLengthDataset(IterableDataset):
|
||||
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
|
||||
: self.seq_length
|
||||
]
|
||||
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
|
||||
: self.seq_length
|
||||
]
|
||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||
if labels.size() == input_ids.size() and (
|
||||
attention_mask.size() == input_ids.size()
|
||||
@@ -123,7 +114,6 @@ class ConstantLengthDataset(IterableDataset):
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
else:
|
||||
LOG.warning(
|
||||
@@ -133,10 +123,8 @@ class ConstantLengthDataset(IterableDataset):
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"position_ids": [],
|
||||
}
|
||||
buffer_len = 0
|
||||
idx = 1
|
||||
|
||||
if example:
|
||||
# FIXME
|
||||
@@ -145,6 +133,11 @@ class ConstantLengthDataset(IterableDataset):
|
||||
input_ids = example["input_ids"]
|
||||
attention_mask = example["attention_mask"]
|
||||
labels = example["labels"]
|
||||
if (
|
||||
buffer["input_ids"]
|
||||
and input_ids[0] == self.tokenizer.bos_token_id
|
||||
):
|
||||
attention_mask[0] = 0
|
||||
|
||||
if add_concat_token:
|
||||
input_ids.append(self.concat_token_id)
|
||||
@@ -155,17 +148,13 @@ class ConstantLengthDataset(IterableDataset):
|
||||
input_ids, dtype=self.tokens_dtype
|
||||
)
|
||||
attention_mask_with_concat = torch.tensor(
|
||||
[idx * m for m in attention_mask], dtype=torch.int16
|
||||
attention_mask, dtype=self.tokens_dtype
|
||||
)
|
||||
labels_with_concat = torch.tensor(
|
||||
labels, dtype=self.tokens_dtype
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
len(input_ids), dtype=self.tokens_dtype
|
||||
)
|
||||
|
||||
buffer["input_ids"].append(input_ids_with_concat)
|
||||
buffer["attention_mask"].append(attention_mask_with_concat)
|
||||
buffer["labels"].append(labels_with_concat)
|
||||
buffer["position_ids"].append(position_ids)
|
||||
buffer_len += len(input_ids)
|
||||
|
||||
@@ -8,18 +8,9 @@ import torch
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
except ImportError:
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -88,16 +79,6 @@ def forward(
|
||||
dtype=torch.int32,
|
||||
device=qkv.device,
|
||||
)
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif position_ids.shape[0] == 1:
|
||||
# special handling using sample packing
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_q_lens = cu_q_lens.squeeze()
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
@@ -132,7 +113,6 @@ def forward(
|
||||
"b s (h d) -> b s h d",
|
||||
h=nheads,
|
||||
)
|
||||
|
||||
return (
|
||||
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
|
||||
None,
|
||||
|
||||
@@ -128,7 +128,6 @@ def xformers_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
# attn_bias=attention_mask,
|
||||
attn_bias=xformers.ops.LowerTriangularMask(),
|
||||
)
|
||||
attn_weights = None
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
"""
|
||||
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
inverted_mask = 1.0 - masked_zero_one_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def hijack_expand_mask():
|
||||
import transformers
|
||||
|
||||
transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access
|
||||
_expand_mask
|
||||
)
|
||||
@@ -1,302 +0,0 @@
|
||||
# pylint: skip-file
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os.path
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Sequence
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import peft
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.relora")
|
||||
|
||||
|
||||
def reset_optimizer(optimizer: torch.optim.Optimizer):
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
param_state = optimizer.state[param]
|
||||
for key in param_state:
|
||||
if "qmap" in key:
|
||||
continue
|
||||
elif key == "step" and isinstance(param_state[key], int):
|
||||
param_state[key] = 0
|
||||
else:
|
||||
param_state[key] = torch.zeros_like(param_state[key])
|
||||
|
||||
|
||||
class ReLoRACallback(TrainerCallback):
|
||||
def __init__(self, cfg: DictDefault):
|
||||
self.relora_steps = cfg.relora_steps
|
||||
self.cpu_offload = cfg.relora_cpu_offload
|
||||
self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
|
||||
self.last_full_model = cfg.base_model
|
||||
|
||||
assert os.path.exists(
|
||||
self.last_full_model
|
||||
), "for ReLORA base_model must be a local path"
|
||||
|
||||
self.num_lora_restarts = 0
|
||||
self.need_full_save = False
|
||||
|
||||
def on_step_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
model: peft.LoraModel,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
**_kwargs,
|
||||
):
|
||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
merge_and_save(
|
||||
model,
|
||||
self.last_full_model,
|
||||
checkpoint_folder,
|
||||
reinit=True,
|
||||
quantized=self.quantised,
|
||||
)
|
||||
reset_optimizer(optimizer)
|
||||
|
||||
if self.quantised:
|
||||
self.last_full_model = checkpoint_folder
|
||||
self.num_lora_restarts += 1
|
||||
|
||||
return control
|
||||
|
||||
def on_save(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
model: peft.LoraModel,
|
||||
**kwargs,
|
||||
):
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
if (
|
||||
state.global_step >= self.relora_steps
|
||||
and state.global_step % self.relora_steps != 0
|
||||
):
|
||||
if self.quantised and self.last_full_model != checkpoint_folder:
|
||||
# ensure the latest full parameter save is in the latest checkpoint
|
||||
# folder, so that automatic pruning of checkpoints does not remove it
|
||||
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
|
||||
chunks = glob.glob(
|
||||
f"{self.last_full_model}/model*.safetensors"
|
||||
) + glob.glob(f"{self.last_full_model}/model*.index.json")
|
||||
for path in chunks:
|
||||
shutil.move(path, checkpoint_folder)
|
||||
self.last_full_model = checkpoint_folder
|
||||
else:
|
||||
model.model.save_pretrained(checkpoint_folder, save_safetensors=True)
|
||||
|
||||
return control
|
||||
|
||||
def on_log(
|
||||
self,
|
||||
_args: TrainingArguments,
|
||||
_state: TrainerState,
|
||||
control: TrainerControl,
|
||||
logs: Dict[str, float],
|
||||
**_kwargs,
|
||||
):
|
||||
logs["num_lora_restarts"] = self.num_lora_restarts
|
||||
return control
|
||||
|
||||
|
||||
class ReLoRAScheduler(LRScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
inner_schedule: LRScheduler,
|
||||
relora_steps: int,
|
||||
warmup_steps: int,
|
||||
min_lr_scale: float = 0.001,
|
||||
) -> None:
|
||||
self.inner_schedule = inner_schedule
|
||||
self.relora_steps = relora_steps
|
||||
self.warmup_steps = warmup_steps
|
||||
self.min_lr_scale = min_lr_scale
|
||||
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
||||
|
||||
def get_lr(self) -> float:
|
||||
self.inner_schedule.last_epoch = self.last_epoch
|
||||
|
||||
original = self.inner_schedule.get_lr()
|
||||
step = self.last_epoch
|
||||
if step < self.relora_steps:
|
||||
scale = 1
|
||||
else:
|
||||
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
|
||||
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
||||
if isinstance(original, Sequence):
|
||||
return [lr * scale for lr in original]
|
||||
else:
|
||||
return original * scale
|
||||
|
||||
|
||||
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
|
||||
model_name = "model.safetensors"
|
||||
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
|
||||
str(Path(path) / f"{model_name}.index.json")
|
||||
):
|
||||
model_name = "pytorch_model.bin"
|
||||
|
||||
index_path = str(Path(path) / f"{model_name}.index.json")
|
||||
if os.path.exists(index_path):
|
||||
data = json.load(open(index_path, "r"))
|
||||
return data["weight_map"]
|
||||
return {key + ".weight": model_name for key in keys}
|
||||
|
||||
|
||||
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer) -> torch.Tensor:
|
||||
if isinstance(layer, peft.tuners.lora.Linear8bitLt) or isinstance(
|
||||
layer, peft.tuners.lora.Linear4bit
|
||||
):
|
||||
adapter = layer.active_adapter
|
||||
return (
|
||||
peft.utils.transpose(
|
||||
layer.lora_B[adapter].weight @ layer.lora_A[adapter].weight,
|
||||
getattr(layer, "fan_in_fan_out", False),
|
||||
)
|
||||
* layer.scaling[adapter]
|
||||
)
|
||||
else:
|
||||
return layer.get_delta_weight()
|
||||
|
||||
|
||||
def merge_and_save(
|
||||
model: peft.LoraModel,
|
||||
model_src: str,
|
||||
model_dst: str,
|
||||
reinit: bool = False,
|
||||
quantized: bool = False,
|
||||
cpu_offload: bool = False,
|
||||
):
|
||||
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
|
||||
|
||||
if not quantized:
|
||||
for key in key_list:
|
||||
try:
|
||||
_parent, target, _target_name = peft.utils._get_submodules(
|
||||
model.model, key
|
||||
)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if isinstance(target, peft.tuners.lora.LoraLayer):
|
||||
update = target.get_delta_weight(target.active_adapter).detach()
|
||||
target.weight.data += update
|
||||
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
return
|
||||
|
||||
os.makedirs(model_dst, exist_ok=True)
|
||||
shard_paths = sharded_paths(model_src, key_list)
|
||||
|
||||
unique_shards = list(set(shard_paths.values()))
|
||||
for shard_path in unique_shards:
|
||||
out_tensors = {}
|
||||
if shard_path.endswith(".safetensors"):
|
||||
in_tensors = st.load_file(str(Path(model_src) / shard_path))
|
||||
else:
|
||||
in_tensors = torch.load(Path(model_src) / shard_path)
|
||||
if "state_dict" in in_tensors:
|
||||
in_tensors = in_tensors["state_dict"]
|
||||
|
||||
for key in key_list:
|
||||
if (key + ".weight") not in shard_paths or shard_paths[
|
||||
key + ".weight"
|
||||
] != shard_path:
|
||||
continue
|
||||
|
||||
try:
|
||||
_parent, target, _target_name = peft.utils._get_submodules(
|
||||
model.model, key
|
||||
)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if isinstance(target, peft.tuners.lora.LoraLayer):
|
||||
orig_weight = in_tensors[key + ".weight"]
|
||||
old_dev = target.weight.device
|
||||
math_dev = "cpu" if cpu_offload else old_dev
|
||||
|
||||
update = lora_delta_weight(target).detach().to(math_dev)
|
||||
new_weight = orig_weight.to(math_dev) + update
|
||||
out_tensors[key + ".weight"] = new_weight
|
||||
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
|
||||
if isinstance(target, peft.tuners.lora.Linear4bit):
|
||||
target.weight = (
|
||||
bnb.nn.Params4bit(
|
||||
new_weight,
|
||||
requires_grad=False,
|
||||
compress_statistics=target.weight.compress_statistics,
|
||||
quant_type=target.weight.quant_type,
|
||||
)
|
||||
.cuda(None)
|
||||
.to(old_dev)
|
||||
)
|
||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||
target.weight = (
|
||||
bnb.nn.Int8Params(new_weight, requires_grad=False)
|
||||
.cuda(None)
|
||||
.to(old_dev)
|
||||
)
|
||||
else:
|
||||
target.weight.data = new_weight.to(old_dev)
|
||||
|
||||
for key in in_tensors:
|
||||
if key not in out_tensors:
|
||||
out_tensors[key] = in_tensors[key]
|
||||
del in_tensors
|
||||
|
||||
out_shard_name = shard_path
|
||||
if out_shard_name.startswith("pytorch_model"):
|
||||
out_shard_name = (
|
||||
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
|
||||
+ ".safetensors"
|
||||
)
|
||||
|
||||
shard_fn = str(Path(model_dst) / out_shard_name)
|
||||
LOG.info(f"saving tensors to {shard_fn}")
|
||||
st.save_file(out_tensors, shard_fn)
|
||||
del out_tensors
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if len(unique_shards) > 1:
|
||||
with open(str(Path(model_dst, "model.safetensors.index.json")), "w") as fd:
|
||||
json.dump({"metadata": {}, "weight_map": shard_paths}, fd)
|
||||
@@ -1,103 +0,0 @@
|
||||
"""
|
||||
Shared utils for the monkeypatches
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
def get_cu_seqlens(attn_mask):
|
||||
"""generate a cumulative sequence length mask for flash attention using attn mask"""
|
||||
if len(attn_mask.shape) == 1:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
|
||||
device = attn_mask.device
|
||||
results = []
|
||||
max_seq_lens = []
|
||||
|
||||
for row in attn_mask:
|
||||
# Exclude zeros to avoid adding their positions to the mask
|
||||
t_non_zeros = row[row != 0]
|
||||
# Find where the sequence number changes (including the first position)
|
||||
seq_change = torch.cat(
|
||||
[
|
||||
torch.tensor([1], dtype=torch.int32, device=device),
|
||||
t_non_zeros[1:] != t_non_zeros[:-1],
|
||||
]
|
||||
)
|
||||
# Get the indices where the sequence changes
|
||||
change_indices = torch.cat(
|
||||
[
|
||||
(seq_change == 1).nonzero(as_tuple=True)[0],
|
||||
torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
|
||||
]
|
||||
)
|
||||
# Calculate the sequence lengths
|
||||
seq_lengths = change_indices[1:] - change_indices[:-1]
|
||||
# Calculate the length of the final sequence or padding
|
||||
final_seq_length = len(row) - change_indices[-1]
|
||||
# Append the length of the final sequence or padding to seq_lengths
|
||||
if final_seq_length.item():
|
||||
seq_lengths = torch.cat(
|
||||
[
|
||||
seq_lengths,
|
||||
torch.tensor(
|
||||
[final_seq_length.item()], dtype=torch.int32, device=device
|
||||
),
|
||||
]
|
||||
)
|
||||
# Calculate the cumulative sequence lengths
|
||||
cu_seqlens = torch.cat(
|
||||
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
|
||||
)
|
||||
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
results.append(cu_seqlens)
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||
if len(position_ids.shape) == 1:
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
device = position_ids.device
|
||||
results = []
|
||||
max_seq_lens = []
|
||||
|
||||
for row in position_ids:
|
||||
# Count the number of consecutive zeros from the right side
|
||||
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
||||
|
||||
# Adjust the row to exclude padding
|
||||
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
||||
|
||||
# Find where the position resets to 0 (indicating a new sequence)
|
||||
seq_starts = torch.cat(
|
||||
[
|
||||
torch.tensor([True], dtype=torch.bool, device=device),
|
||||
adjusted_row[1:] == 0,
|
||||
]
|
||||
)
|
||||
# Get the indices where the sequence starts
|
||||
start_indices = torch.cat(
|
||||
[
|
||||
(seq_starts).nonzero(as_tuple=True)[0],
|
||||
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
||||
]
|
||||
)
|
||||
# Calculate the sequence lengths
|
||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
||||
# Calculate the cumulative sequence lengths
|
||||
cu_seqlens = torch.cat(
|
||||
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
|
||||
)
|
||||
# Append the padding length to the cumulative sequence lengths
|
||||
if padding_length:
|
||||
cu_seqlens = torch.cat(
|
||||
[cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
|
||||
)
|
||||
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
results.append(cu_seqlens)
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
@@ -66,11 +66,7 @@ class SystemDataPrompter(AlpacaPrompter):
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
formatted_sys_prompt = (
|
||||
self.system_format.format(system=system)
|
||||
if system and self.system_format
|
||||
else ""
|
||||
)
|
||||
formatted_sys_prompt = f"### System:\n{system}\n\n" if system else ""
|
||||
if input:
|
||||
res = formatted_sys_prompt + self.turn_format.format(
|
||||
instruction=instruction, input=input
|
||||
@@ -90,20 +86,12 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
|
||||
"""
|
||||
|
||||
def match_prompt_style(self):
|
||||
# pylint: disable=duplicate-code
|
||||
if self.prompt_style == PromptStyle.INSTRUCT.value:
|
||||
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
|
||||
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
|
||||
if self.prompt_style == PromptStyle.CHAT.value:
|
||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||
self.system_format = "SYSTEM: {system}\n"
|
||||
if self.prompt_style == PromptStyle.CHATML.value:
|
||||
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.turn_no_input_format = (
|
||||
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||
|
||||
|
||||
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
|
||||
@@ -149,12 +137,3 @@ def load_open_orca(tokenizer, cfg):
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_open_orca_chatml(tokenizer, cfg):
|
||||
return OpenOrcaPromptTokenizingStrategy(
|
||||
OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
@@ -1,205 +0,0 @@
|
||||
"""
|
||||
Prompt Strategy for finetuning Llama2 chat models
|
||||
see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for ma reference implementation.
|
||||
|
||||
This implementation is based on the Vicuna PR and the fastchat repo, see also:
|
||||
https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847
|
||||
|
||||
Use dataset type: "llama2_chat" in conig.yml to use this prompt style.
|
||||
|
||||
E.g. in the config.yml:
|
||||
```
|
||||
datasets:
|
||||
- path: llama_finetune_train.jsonl
|
||||
type: llama2_chat
|
||||
```
|
||||
|
||||
The dataset itself should look like this:
|
||||
```
|
||||
{'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]}
|
||||
```
|
||||
in a jsonl file. The first message should be from the human, the second from gpt.
|
||||
For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns).
|
||||
|
||||
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generator, List, Sequence
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
|
||||
|
||||
@dataclass
|
||||
class Llama2ChatConversation:
|
||||
"""A class that manages prompt templates and keeps all conversation history.
|
||||
copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py"""
|
||||
|
||||
name: str = "llama2"
|
||||
# The system prompt
|
||||
system: str = (
|
||||
"[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
|
||||
)
|
||||
roles: Sequence[str] = ("[INST]", "[/INST]")
|
||||
messages: List[List[str]] = field(default_factory=list)
|
||||
offset: int = 0
|
||||
sep = " "
|
||||
sep2 = " </s><s>"
|
||||
stop_token_ids = [2]
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
"""Get the prompt for generation."""
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = ""
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if (i == len(self.messages) - 1) and (role == self.roles[0]):
|
||||
# last message is from user (due to length),
|
||||
# return prompt without it for training
|
||||
return ret
|
||||
if i == 0:
|
||||
ret += self.system + message.strip()
|
||||
else:
|
||||
ret += role + " " + message.strip() + seps[i % 2]
|
||||
return ret
|
||||
|
||||
def append_message(self, role: str, message: str):
|
||||
"""Append a new message."""
|
||||
self.messages.append([role, message])
|
||||
|
||||
|
||||
class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.sequence_len = 4096
|
||||
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
conv = next(self.prompter.build_prompt(prompt))
|
||||
conversation_str = conv.get_prompt()
|
||||
|
||||
# Tokenize conversations
|
||||
input_ids = self.tokenizer(
|
||||
conversation_str,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=self.sequence_len,
|
||||
truncation=True,
|
||||
).input_ids[0]
|
||||
target = input_ids.clone()
|
||||
|
||||
# Mask targets. Only compute loss on the assistant outputs.
|
||||
sep = conv.roles[1]
|
||||
|
||||
total_len = int(target.ne(self.tokenizer.pad_token_id).sum())
|
||||
|
||||
turns = conversation_str.split(conv.sep2)
|
||||
cur_len = 1
|
||||
target[:cur_len] = IGNORE_TOKEN_ID
|
||||
for turn in turns:
|
||||
if turn == "":
|
||||
break
|
||||
turn_len = len(self.tokenizer(turn).input_ids)
|
||||
|
||||
parts = turn.split(sep)
|
||||
if len(parts) != 2:
|
||||
break
|
||||
parts[0] += sep
|
||||
# "-1" is hardcoded for the LLaMA tokenizer to make the offset correct.
|
||||
instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1
|
||||
|
||||
# Ignore the user instructions
|
||||
target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID
|
||||
cur_len += turn_len + 2 # due to length of role token
|
||||
|
||||
target[cur_len:] = IGNORE_TOKEN_ID
|
||||
|
||||
if cur_len < self.sequence_len:
|
||||
if cur_len != total_len:
|
||||
target[:] = IGNORE_TOKEN_ID
|
||||
logging.warning(
|
||||
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
||||
f" (ignored)"
|
||||
)
|
||||
|
||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist()
|
||||
input_ids = input_ids.tolist()
|
||||
target = target.tolist()
|
||||
# this is a fix for the tokenizer which tokenizes [ differently with eos tokens and
|
||||
# follows the original llama implementation
|
||||
for i in range(2, total_len - 2):
|
||||
if input_ids[i] == 29961:
|
||||
input_ids[i] = 518
|
||||
if target[i] == 29961:
|
||||
target[i] = 518
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"labels": target,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
|
||||
class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for Llama2 models.
|
||||
"""
|
||||
|
||||
system_prompt = (
|
||||
"[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
|
||||
)
|
||||
|
||||
def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:
|
||||
# see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78
|
||||
source = source["conversations"] # fix data structure for datasets
|
||||
|
||||
# if system prompt provided, use it
|
||||
if source[0]["from"] == "system":
|
||||
system = f"[INST] <<SYS>>\n{source[0]['value']}\n<</SYS>>\n\n"
|
||||
source = source[1:]
|
||||
else:
|
||||
system = self.system_prompt
|
||||
|
||||
conv = Llama2ChatConversation(system=system)
|
||||
|
||||
if len(source) < 2:
|
||||
# If there isn't a back and forth conversation, ignore it
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
raise IndexError
|
||||
|
||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||
|
||||
if roles[source[0]["from"]] != conv.roles[0]:
|
||||
# Skip the first one if it is not from human
|
||||
source = source[1:]
|
||||
|
||||
conv.messages = [] # pylint: disable=R0801
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
if sentence["value"]:
|
||||
conv.append_message(role, sentence["value"])
|
||||
yield conv
|
||||
|
||||
|
||||
def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy:
|
||||
return LLama2ChatTokenizingStrategy(
|
||||
Llama2ChatPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
@@ -1,46 +0,0 @@
|
||||
"""
|
||||
Prompt Strategy for finetuning Orca Mini (v2) models
|
||||
see also https://huggingface.co/psmathur/orca_mini_v2_7b for more information
|
||||
|
||||
Use dataset type: orcamini in conig.yml to use this prompt style.
|
||||
|
||||
Compared to the alpaca_w_system.open_orca dataset type,
|
||||
this one specifies the system prompt with "### System:".
|
||||
|
||||
Not suited/tested for multiple-turn conversations without further adjustments.
|
||||
"""
|
||||
from typing import Generator, Union
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter
|
||||
|
||||
|
||||
class OrcaMiniPrompter(AlpacaPrompter):
|
||||
"""Adjusted Prompter for Orca Mini (v2) datasets"""
|
||||
|
||||
def match_prompt_style(self):
|
||||
self.turn_no_input_format = (
|
||||
"### System:\n{system}\n\n### User:\n{instruction}\n\n### Response:\n"
|
||||
)
|
||||
|
||||
def build_prompt_w_system(
|
||||
self,
|
||||
system: str,
|
||||
instruction: str,
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
res = self.turn_no_input_format.format(system=system, instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return OpenOrcaPromptTokenizingStrategy(
|
||||
OrcaMiniPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
@@ -16,7 +16,6 @@ class PromptStyle(Enum):
|
||||
|
||||
INSTRUCT = "instruct"
|
||||
CHAT = "chat"
|
||||
CHATML = "chatml"
|
||||
|
||||
|
||||
class AlpacaPrompter:
|
||||
@@ -26,7 +25,6 @@ class AlpacaPrompter:
|
||||
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
||||
system_format: str
|
||||
turn_format: str
|
||||
turn_no_input_format: str
|
||||
prompt_style: Optional[PromptStyle] = None
|
||||
@@ -36,23 +34,14 @@ class AlpacaPrompter:
|
||||
self.match_prompt_style()
|
||||
|
||||
def match_prompt_style(self):
|
||||
# pylint: disable=duplicate-code
|
||||
if self.prompt_style == PromptStyle.INSTRUCT.value:
|
||||
self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
self.turn_no_input_format = (
|
||||
"### Instruction:\n{instruction}\n\n### Response:\n"
|
||||
)
|
||||
self.system_format = "### System:\n{system}\n\n"
|
||||
if self.prompt_style == PromptStyle.CHAT.value:
|
||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||
self.system_format = "SYSTEM: {system}\n"
|
||||
if self.prompt_style == PromptStyle.CHATML.value:
|
||||
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.turn_no_input_format = (
|
||||
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
@@ -271,11 +260,6 @@ class Conversation:
|
||||
self.messages.append([role, message])
|
||||
|
||||
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
@@ -332,7 +316,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
conv.messages = []
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
assert role == conv.roles[j % 2]
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
for part in conv.get_prompt():
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
"""Benchmarking and measurement utilities"""
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
|
||||
|
||||
def gpu_memory_usage(device):
|
||||
if isinstance(device, torch.device):
|
||||
device = device.index
|
||||
if isinstance(device, str) and device.startswith("cuda:"):
|
||||
device = int(device[5:])
|
||||
|
||||
# NB torch.cuda.memory_usage returns zero so we use lower level api
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
return info.used / 1024.0**3
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
log.info(
|
||||
f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
|
||||
)
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Callbacks for Trainer class"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
@@ -12,10 +11,6 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||
"""Callback to save the PEFT adapter"""
|
||||
@@ -33,9 +28,7 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
||||
)
|
||||
|
||||
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
||||
kwargs["model"].save_pretrained(
|
||||
peft_model_path, save_safetensors=args.save_safetensors
|
||||
)
|
||||
kwargs["model"].save_pretrained(peft_model_path)
|
||||
|
||||
return control
|
||||
|
||||
@@ -74,25 +67,3 @@ class SaveBetterTransformerModelCallback(
|
||||
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
||||
control.should_save = False
|
||||
return control
|
||||
|
||||
|
||||
class PrintGPUStatsCallback(
|
||||
TrainerCallback
|
||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
||||
"""Callback to print GPU utilization"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.logged = False
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if not self.logged:
|
||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
||||
self.logged = True
|
||||
return control
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
"""
|
||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSeq2Seq:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs received, as well as the labels and position_ids
|
||||
|
||||
Args:
|
||||
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
||||
The tokenizer used for encoding the data.
|
||||
model ([`PreTrainedModel`]):
|
||||
The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
|
||||
prepare the *decoder_input_ids*
|
||||
|
||||
This is useful when using *label_smoothing* to avoid calculating loss twice.
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||
among:
|
||||
|
||||
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence is provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (`int`, *optional*):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||
7.5 (Volta).
|
||||
label_pad_token_id (`int`, *optional*, defaults to -100):
|
||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||
return_tensors (`str`):
|
||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
model: Optional[Any] = None
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
label_pad_token_id: int = -100
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
labels = None
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
|
||||
for feature_name, pad_token_id in [
|
||||
("labels", self.label_pad_token_id),
|
||||
("position_ids", self.position_pad_token_id),
|
||||
]:
|
||||
feat = (
|
||||
[feature[feature_name] for feature in features]
|
||||
if feature_name in features[0].keys()
|
||||
else None
|
||||
)
|
||||
labels = feat if feat and feature_name == "labels" else labels
|
||||
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
|
||||
# same length to return tensors.
|
||||
if feat is not None:
|
||||
max_feature_length = max(len(l) for l in feat) # noqa: E741
|
||||
if self.pad_to_multiple_of is not None:
|
||||
max_feature_length = (
|
||||
(max_feature_length + self.pad_to_multiple_of - 1)
|
||||
// self.pad_to_multiple_of
|
||||
* self.pad_to_multiple_of
|
||||
)
|
||||
|
||||
padding_side = self.tokenizer.padding_side
|
||||
for feature in features:
|
||||
remainder = [pad_token_id] * (
|
||||
max_feature_length - len(feature[feature_name])
|
||||
)
|
||||
if isinstance(feature[feature_name], list):
|
||||
feature[feature_name] = (
|
||||
feature[feature_name] + remainder
|
||||
if padding_side == "right"
|
||||
else remainder + feature[feature_name]
|
||||
)
|
||||
elif padding_side == "right":
|
||||
feature[feature_name] = np.concatenate(
|
||||
[feature[feature_name], remainder]
|
||||
).astype(np.int64)
|
||||
else:
|
||||
feature[feature_name] = np.concatenate(
|
||||
[remainder, feature[feature_name]]
|
||||
).astype(np.int64)
|
||||
|
||||
features = self.tokenizer.pad(
|
||||
features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
|
||||
# prepare decoder_input_ids
|
||||
if (
|
||||
labels is not None
|
||||
and self.model is not None
|
||||
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
|
||||
):
|
||||
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
|
||||
labels=features["labels"]
|
||||
)
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
return features
|
||||
@@ -1,19 +1,13 @@
|
||||
"""Module containing data utilities"""
|
||||
import functools
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
concatenate_datasets,
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
@@ -41,7 +35,6 @@ from axolotl.prompters import (
|
||||
ShareGPTPrompter,
|
||||
SummarizeTLDRPrompter,
|
||||
)
|
||||
from axolotl.utils.distributed import barrier, is_main_process
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -116,7 +109,6 @@ def load_tokenized_prepared_datasets(
|
||||
local_path = Path(d.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
name=d.name,
|
||||
@@ -270,12 +262,20 @@ def load_tokenized_prepared_datasets(
|
||||
raise ValueError(
|
||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||
)
|
||||
LOG.info("merging datasets")
|
||||
dataset = concatenate_datasets(datasets)
|
||||
LOG.info("tokenizing, merging, and shuffling master dataset")
|
||||
|
||||
if len(datasets) > 1:
|
||||
LOG.info("shuffle merged datasets")
|
||||
dataset = dataset.shuffle(seed=seed)
|
||||
samples: List[int] = []
|
||||
chunk_size = 1000
|
||||
for d in datasets:
|
||||
d_iter = iter(d)
|
||||
while True:
|
||||
chunk = list(itertools.islice(d_iter, chunk_size))
|
||||
if not chunk:
|
||||
break
|
||||
samples.extend(chunk)
|
||||
|
||||
LOG.info("shuffle")
|
||||
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
dataset.save_to_disk(prepared_ds_path)
|
||||
@@ -374,12 +374,11 @@ def load_prepare_datasets(
|
||||
dataset = Dataset.from_list(list(constant_len_dataset))
|
||||
|
||||
# filter out bad data
|
||||
# TODO convert to dataset.filter(...)
|
||||
dataset = Dataset.from_list(
|
||||
[
|
||||
d
|
||||
for d in dataset
|
||||
if len(d["input_ids"]) <= cfg.sequence_len
|
||||
if len(d["input_ids"]) < cfg.sequence_len
|
||||
and len(d["input_ids"]) > 0
|
||||
and len(d["input_ids"]) == len(d["attention_mask"])
|
||||
and len(d["input_ids"]) == len(d["labels"])
|
||||
@@ -414,51 +413,7 @@ def load_prepare_datasets(
|
||||
)
|
||||
|
||||
if cfg.val_set_size:
|
||||
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
||||
to_hash_train = (
|
||||
dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "train"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
)
|
||||
to_hash_test = (
|
||||
dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "test"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
)
|
||||
train_fingerprint = hashlib.md5(
|
||||
to_hash_train.encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
test_fingerprint = hashlib.md5(
|
||||
to_hash_test.encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
|
||||
if is_main_process():
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed or 42,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
barrier()
|
||||
if not is_main_process():
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed or 42,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
barrier()
|
||||
|
||||
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
else:
|
||||
|
||||
@@ -1,288 +0,0 @@
|
||||
# pylint: skip-file
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.dataloader")
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
# First-fit-decreasing bin packing
|
||||
# Check if a[] could fit in n bins with capacity c
|
||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||
|
||||
a = np.sort(a)[::-1]
|
||||
bins = np.full((n,), c, dtype=a.dtype)
|
||||
for size in a:
|
||||
not_found = True
|
||||
for idx in range(n):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
not_found = False
|
||||
break
|
||||
|
||||
if not_found:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||
# First-fit-decreasing bin packing (with result return)
|
||||
|
||||
indices = np.argsort(a)[::-1]
|
||||
a = a[indices]
|
||||
|
||||
bins: List[Any] = []
|
||||
bins_result: List[Any] = []
|
||||
for a_id, size in enumerate(a):
|
||||
add_new = True
|
||||
for idx in range(len(bins)):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
bins_result[idx].append(indices[a_id] + start_index)
|
||||
add_new = False
|
||||
break
|
||||
|
||||
if add_new:
|
||||
bins.append(c - size)
|
||||
bins_result.append([indices[a_id] + start_index])
|
||||
|
||||
return bins_result, len(a)
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate(
|
||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||
):
|
||||
"""
|
||||
:param lengths: array of lengths of each sample
|
||||
:param lengths_cumsum: cumulative sum of consecutive lengths
|
||||
:param rank: rank for this process
|
||||
:param c: length of tokens per batch
|
||||
:param n: number of ranks
|
||||
:return:
|
||||
"""
|
||||
# Dynamic batch allocator, similar to Multifit
|
||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||
|
||||
s = 0
|
||||
start_index = 0
|
||||
result = []
|
||||
result_totseqs = []
|
||||
|
||||
while True:
|
||||
# binary search [left, right)
|
||||
left = 1
|
||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||
|
||||
while right - left > 1:
|
||||
mid = (left + right) // 2
|
||||
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
||||
left = mid
|
||||
else:
|
||||
right = mid
|
||||
|
||||
# use length left
|
||||
batch, tot_seqs = ffd_with_result(
|
||||
lengths[start_index : start_index + left], c, start_index
|
||||
)
|
||||
if len(batch) < n:
|
||||
break
|
||||
|
||||
start_index += left
|
||||
s = lengths_cumsum[start_index - 1]
|
||||
|
||||
# add local rank
|
||||
result.append(batch[rank])
|
||||
# add total seqs for all ranks
|
||||
result_totseqs.append(tot_seqs)
|
||||
# yield batch[rank], tot_seqs, s, len(result) * c * n
|
||||
return result, result_totseqs, s, len(result) * c * n
|
||||
|
||||
|
||||
def chunk(iterable, n):
|
||||
"""
|
||||
Chunk data into tuples of length n
|
||||
"""
|
||||
# batched('ABCDEFG', 3) --> ABC DEF G
|
||||
if n < 1:
|
||||
raise ValueError("n must be at least one")
|
||||
it = iter(iterable)
|
||||
while batch := tuple(itertools.islice(it, n)):
|
||||
yield batch
|
||||
|
||||
|
||||
def hash_indices(lst: List[int]) -> str:
|
||||
# Convert the list of integers to a string representation
|
||||
concatenated = ",".join(map(str, lst))
|
||||
|
||||
# Generate the hash
|
||||
sha256 = hashlib.sha256()
|
||||
sha256.update(concatenated.encode())
|
||||
|
||||
return sha256.hexdigest()
|
||||
|
||||
|
||||
class MultipackDistributedDataloader:
|
||||
"""Unpadded data loading using Multipack.
|
||||
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
||||
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Any,
|
||||
collate_fn: Callable,
|
||||
seq_max_length: int = 2048,
|
||||
batch_size: int = 1,
|
||||
sampler: Union[Sampler, DistributedSampler] = None,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
sample_packing_seq_len_multiplier: int = 1,
|
||||
device_count: int = 1,
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
self.lengths = (
|
||||
dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
)
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
assert batch_size % sample_packing_seq_len_multiplier == 0
|
||||
assert batch_size >= sample_packing_seq_len_multiplier
|
||||
self.sampler = sampler
|
||||
self.batch_size = batch_size
|
||||
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
|
||||
self.seq_max_length = seq_max_length
|
||||
self.batch_max_length = batch_size * seq_max_length
|
||||
self.collate_fn = collate_fn
|
||||
|
||||
self.num_replicas = 1
|
||||
self.rank = 0
|
||||
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.device_count = device_count
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
LOG.info("generating packed batches")
|
||||
if self.sampler:
|
||||
indices = [idx for idx in self.sampler]
|
||||
else:
|
||||
indices = range(0, len(self.dataset))
|
||||
|
||||
LOG.info(hash_indices(indices))
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
batches, totseqs, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=self.rank,
|
||||
# c=self.batch_max_length,
|
||||
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
||||
n=self.num_replicas,
|
||||
)
|
||||
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used += total_used
|
||||
self.eff_total_slots += total_slots
|
||||
|
||||
return batches, totseqs
|
||||
|
||||
def __iter__(self):
|
||||
if hasattr(self.sampler, "set_epoch"):
|
||||
new_epoch = self.sampler.epoch + 1
|
||||
self.sampler.set_epoch(new_epoch)
|
||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||
all_batches, _ = self.generate_batches(set_stats=True)
|
||||
features = self.dataset.features.keys()
|
||||
len_remaining = self._len_est()
|
||||
for batches in chunk(
|
||||
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
|
||||
):
|
||||
chunked_data = []
|
||||
attn_mask_cum_idx = 0
|
||||
for batch in batches:
|
||||
concatenated = {}
|
||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||
for feature in features:
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||
for idx, item in enumerate(batched_data)
|
||||
if feature in item
|
||||
]
|
||||
attn_mask_cum_idx += len(batched_data)
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature])
|
||||
for item in batched_data
|
||||
if feature in item
|
||||
]
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
chunked_data.append(concatenated)
|
||||
yield self.collate_fn(chunked_data)
|
||||
len_remaining -= 1
|
||||
if not len_remaining:
|
||||
return
|
||||
|
||||
def _len_est(self):
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // self.device_count
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
)
|
||||
|
||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||
return (
|
||||
math.floor(
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
// self.seq_max_length
|
||||
// self.batch_size
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
|
||||
# the same share of total tokens
|
||||
# if not self.eff_total_used:
|
||||
# batches, _ = self.generate_batches(set_stats=True)
|
||||
# LOG.info(
|
||||
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
# f"actual packing efficiency: {self.efficiency()}"
|
||||
# )
|
||||
return max(1, self._len_est())
|
||||
|
||||
def len_w_stats(self):
|
||||
if not self.eff_total_used:
|
||||
batches, _ = self.generate_batches(set_stats=True)
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"actual packing efficiency: {self.efficiency()}"
|
||||
)
|
||||
return max(1, self._len_est())
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
@@ -1,41 +0,0 @@
|
||||
"""
|
||||
utility helpers for distributed checks
|
||||
"""
|
||||
import torch.distributed as dist
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerate = None # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def load_accelerate():
|
||||
global accelerate # pylint: disable=global-statement
|
||||
accelerate = Accelerator()
|
||||
|
||||
|
||||
def is_distributed():
|
||||
"""
|
||||
Check if distributed training is initialized.
|
||||
"""
|
||||
global accelerate # pylint: disable=global-statement
|
||||
if not accelerate:
|
||||
accelerate = Accelerator()
|
||||
return dist.is_available() and dist.is_initialized()
|
||||
|
||||
|
||||
def barrier():
|
||||
"""
|
||||
Acts as a barrier to wait for all processes. This ensures that all processes
|
||||
reach the barrier before proceeding further.
|
||||
"""
|
||||
if is_distributed():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
"""
|
||||
Check if the current process is the main process.
|
||||
If not in distributed mode, always return True.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return True
|
||||
return dist.get_rank() == 0
|
||||
@@ -22,7 +22,6 @@ from transformers import ( # noqa: F401
|
||||
)
|
||||
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -37,26 +36,20 @@ def load_tokenizer(
|
||||
tokenizer_type,
|
||||
cfg,
|
||||
):
|
||||
tokenizer_kwargs = {}
|
||||
use_fast = True # this is the default
|
||||
if cfg.tokenizer_use_fast is not None:
|
||||
use_fast = cfg.tokenizer_use_fast
|
||||
if cfg.tokenizer_legacy is not None:
|
||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||
if tokenizer_type:
|
||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
@@ -84,22 +77,17 @@ def load_tokenizer(
|
||||
|
||||
|
||||
def load_model(
|
||||
cfg, tokenizer
|
||||
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
|
||||
):
|
||||
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
Load a model from a base model and a model type.
|
||||
"""
|
||||
base_model = cfg.base_model
|
||||
base_model_config = cfg.base_model_config
|
||||
model_type = cfg.model_type
|
||||
adapter = cfg.adapter
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
load_in_8bit = cfg.load_in_8bit
|
||||
cfg.is_llama_derived_model = (
|
||||
"llama" in base_model
|
||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||
or cfg.is_llama_derived_model
|
||||
cfg.is_llama_derived_model = "llama" in base_model or (
|
||||
cfg.model_type and "llama" in cfg.model_type.lower()
|
||||
)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||
@@ -144,14 +132,6 @@ def load_model(
|
||||
LOG.info("patching with xpos rope")
|
||||
replace_llama_rope_with_xpos_rope()
|
||||
|
||||
if cfg.is_llama_derived_model and (
|
||||
cfg.max_packed_sequence_len or cfg.sample_packing
|
||||
):
|
||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||
|
||||
LOG.info("patching _expand_mask")
|
||||
hijack_expand_mask()
|
||||
|
||||
if cfg.bf16 or cfg.bfloat16:
|
||||
torch_dtype = torch.bfloat16
|
||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||
@@ -235,15 +215,14 @@ def load_model(
|
||||
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
config = LlamaConfig.from_pretrained(
|
||||
base_model_config, rope_scaling=cfg.rope_scaling
|
||||
)
|
||||
config = LlamaConfig.from_pretrained(base_model_config)
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map="auto" if cfg.world_size == 1 else cfg.device_map,
|
||||
**model_kwargs,
|
||||
)
|
||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||
@@ -278,6 +257,7 @@ def load_model(
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -308,6 +288,7 @@ def load_model(
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -321,6 +302,7 @@ def load_model(
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -342,9 +324,6 @@ def load_model(
|
||||
)
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
if not cfg.gptq and (
|
||||
(cfg.adapter == "lora" and load_in_8bit)
|
||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||
@@ -381,9 +360,6 @@ def load_model(
|
||||
module.scales = module.scales.half()
|
||||
module.bias = module.bias.half()
|
||||
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||
|
||||
if (
|
||||
torch.cuda.device_count() > 1
|
||||
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
||||
@@ -415,8 +391,6 @@ def load_adapter(model, cfg, adapter):
|
||||
|
||||
if adapter is None:
|
||||
return model, None
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
if adapter in ["lora", "qlora"]:
|
||||
return load_lora(model, cfg)
|
||||
if adapter == "llama-adapter":
|
||||
|
||||
@@ -1,34 +1,26 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import numpy as np
|
||||
import torch.cuda
|
||||
import transformers
|
||||
from datasets import Dataset, set_caching_enabled
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
PrintGPUStatsCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
SavePeftModelCallback,
|
||||
)
|
||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||
from axolotl.utils.schedulers import (
|
||||
InterpolatingLogScheduler,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
@@ -37,68 +29,6 @@ from axolotl.utils.schedulers import (
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def weighted_cross_entropy(
|
||||
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
|
||||
):
|
||||
# Flatten the logits, labels, and weights tensors
|
||||
logits = logits.view(
|
||||
-1, logits.size(-1)
|
||||
) # logits becomes of shape [batch_size*sequence_length, vocab_size]
|
||||
labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length]
|
||||
weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length]
|
||||
|
||||
# Compute the unweighted cross entropy loss
|
||||
losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
|
||||
|
||||
# Apply the weights to the losses and compute their sum
|
||||
return (weights * losses).sum()
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def create_weighted_mask(labels: torch.Tensor):
|
||||
# Check if the tensor is 2D. If not, unsqueeze it to make it 2D
|
||||
if len(labels.shape) == 1:
|
||||
labels = labels.unsqueeze(0)
|
||||
|
||||
weights = torch.zeros_like(labels).float()
|
||||
for i in range(labels.shape[0]):
|
||||
mask = labels[i] != -100
|
||||
|
||||
# Create a tensor to track group ids
|
||||
group_ids = torch.zeros_like(labels[i]).int()
|
||||
curr_group_id = 0
|
||||
|
||||
for j in range(1, len(labels[i])):
|
||||
if mask[j] and not mask[j - 1]: # switch from masked to unmasked label
|
||||
curr_group_id += 1 # start new group
|
||||
group_ids[j] = (
|
||||
curr_group_id if mask[j] else 0
|
||||
) # assign group id if unmasked label
|
||||
|
||||
# Count only unmasked labels in each group
|
||||
group_counts = torch.bincount(group_ids[mask])
|
||||
|
||||
mask_weights = torch.zeros_like(labels[i]).float()
|
||||
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
|
||||
|
||||
weights[i] = mask_weights
|
||||
|
||||
return weights.squeeze() # squeeze the output to match the input dimension
|
||||
|
||||
|
||||
def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
||||
logits = (
|
||||
model_output["logits"] if isinstance(model_output, dict) else model_output[0]
|
||||
)
|
||||
if shift_labels:
|
||||
logits = logits[..., :-1, :].contiguous()
|
||||
labels = labels[..., 1:].contiguous()
|
||||
|
||||
weights = create_weighted_mask(labels)
|
||||
return weighted_cross_entropy(logits, labels, weights)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
@@ -109,22 +39,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
sample_packing_seq_len_multiplier: int = field(
|
||||
default=1,
|
||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -162,64 +76,6 @@ class AxolotlTrainer(Trainer):
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size > 1 and self.args.sample_packing:
|
||||
return DistributedSampler(
|
||||
self.train_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
if self.args.sample_packing:
|
||||
train_sampler = self._get_train_sampler()
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
self.train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=train_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(
|
||||
self, eval_dataset: Optional[Dataset] = None
|
||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
if self.args.sample_packing:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=eval_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
@@ -250,121 +106,10 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
def add_position_ids(sample):
|
||||
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
||||
return sample
|
||||
|
||||
|
||||
def drop_long_seq(sample, sequence_len=2048):
|
||||
return len(sample["input_ids"]) <= sequence_len
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_datasets_caching():
|
||||
try:
|
||||
set_caching_enabled(False)
|
||||
yield
|
||||
finally:
|
||||
set_caching_enabled(True)
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if cfg.sample_packing:
|
||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
|
||||
add_position_ids, num_proc=os.cpu_count()
|
||||
)
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
|
||||
add_position_ids, num_proc=os.cpu_count()
|
||||
)
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
if cfg.sample_packing:
|
||||
# we have to drop anything longer then sequence len otherwise
|
||||
# flash attention with position ids fails
|
||||
if not cfg.total_num_tokens:
|
||||
LOG.info("calculating total_num_tokens")
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.data.column("input_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||
.values
|
||||
)
|
||||
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
|
||||
cfg.total_num_tokens = total_num_tokens
|
||||
|
||||
if cfg.sample_packing_eff_est:
|
||||
total_num_steps = (
|
||||
# match count to len est in dataloader
|
||||
(
|
||||
math.floor(
|
||||
0.99
|
||||
* cfg.total_num_tokens
|
||||
/ cfg.sample_packing_eff_est
|
||||
/ cfg.sequence_len
|
||||
// cfg.batch_size
|
||||
// int(os.environ.get("WORLD_SIZE", 1))
|
||||
)
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
)
|
||||
LOG.info(
|
||||
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
|
||||
)
|
||||
else:
|
||||
sampler = RandomSampler(train_dataset)
|
||||
data_loader = MultipackDistributedDataloader(
|
||||
train_dataset,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
||||
collate_fn=DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
),
|
||||
sampler=sampler,
|
||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
data_loader_len = data_loader.len_w_stats()
|
||||
actual_eff = data_loader.efficiency()
|
||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len
|
||||
* cfg.micro_batch_size
|
||||
* cfg.num_epochs
|
||||
// cfg.batch_size
|
||||
)
|
||||
)
|
||||
LOG.info(
|
||||
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
|
||||
)
|
||||
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
|
||||
else:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
LOG.info(f"total_num_steps: {total_num_steps}")
|
||||
return total_num_steps
|
||||
|
||||
|
||||
def setup_fsdp_envs(cfg):
|
||||
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||
if cfg.fsdp_config.fsdp_sync_module_states:
|
||||
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
|
||||
if cfg.fsdp_config.fsdp_state_dict_type:
|
||||
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
if cfg.fsdp:
|
||||
setup_fsdp_envs(cfg)
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
warmup_steps = (
|
||||
cfg.warmup_steps
|
||||
if cfg.warmup_steps is not None
|
||||
@@ -444,14 +189,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
if cfg.save_safetensors:
|
||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||
|
||||
if cfg.sample_packing_eff_est:
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_efficiency"
|
||||
] = cfg.sample_packing_eff_est
|
||||
|
||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
||||
max_seq_length=cfg.sequence_len,
|
||||
per_device_train_batch_size=cfg.micro_batch_size,
|
||||
per_device_eval_batch_size=cfg.eval_batch_size
|
||||
if cfg.eval_batch_size is not None
|
||||
@@ -465,7 +203,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
||||
save_steps=cfg.save_steps,
|
||||
output_dir=cfg.output_dir,
|
||||
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
||||
save_total_limit=3,
|
||||
load_best_model_at_end=(
|
||||
cfg.load_best_model_at_end is not False
|
||||
and cfg.val_set_size > 0
|
||||
@@ -483,8 +221,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||
else "cosine",
|
||||
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
|
||||
@@ -556,19 +292,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
||||
|
||||
callbacks = []
|
||||
callbacks.append(PrintGPUStatsCallback(cfg))
|
||||
|
||||
if cfg.relora_steps:
|
||||
relora_steps = int(cfg.relora_steps)
|
||||
relora_warmup_steps = int(cfg.relora_warmup_steps)
|
||||
callbacks.append(ReLoRACallback(cfg))
|
||||
|
||||
(optimizer, lr_scheduler) = trainer_kwargs["optimizers"]
|
||||
trainer_kwargs["optimizers"] = (
|
||||
optimizer,
|
||||
ReLoRAScheduler(optimizer, lr_scheduler, relora_steps, relora_warmup_steps),
|
||||
)
|
||||
|
||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||
if cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
@@ -591,11 +314,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
if cfg.collator_pad_to_longest:
|
||||
data_collator_kwargs["padding"] = "longest"
|
||||
else:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 8
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||
from functools import partial
|
||||
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
add_mem_tokens,
|
||||
get_mem_id,
|
||||
@@ -623,7 +346,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
|
||||
@@ -8,19 +8,6 @@ LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def validate_config(cfg):
|
||||
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
||||
raise ValueError(
|
||||
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
|
||||
)
|
||||
if cfg.max_packed_sequence_len:
|
||||
LOG.warning(
|
||||
str(
|
||||
PendingDeprecationWarning(
|
||||
"max_packed_sequence_len will be deprecated in favor of sample_packing"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||
raise ValueError(
|
||||
"please set only one of gradient_accumulation_steps or batch_size"
|
||||
@@ -61,9 +48,6 @@ def validate_config(cfg):
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.relora_steps and cfg.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
|
||||
if cfg.trust_remote_code:
|
||||
LOG.warning(
|
||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||
@@ -113,24 +97,6 @@ def validate_config(cfg):
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.model_revision:
|
||||
raise ValueError(
|
||||
"model_revision is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention:
|
||||
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.xformers_attention:
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
|
||||
@@ -9,8 +9,6 @@ def setup_wandb_env_vars(cfg):
|
||||
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
||||
cfg.use_wandb = True
|
||||
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
|
||||
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
|
||||
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
||||
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
||||
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,30 +0,0 @@
|
||||
"""
|
||||
Unit tests for the monkeypatch utils
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
|
||||
|
||||
|
||||
class TestMonkeyPatchUtils(unittest.TestCase):
|
||||
"""
|
||||
Unit test class for monkeypatch utils
|
||||
"""
|
||||
|
||||
def test_get_cu_seqlens_1d(self):
|
||||
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
||||
target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
|
||||
self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res))
|
||||
|
||||
def test_get_cu_seqlens_from_pos_ids_1d(self):
|
||||
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
|
||||
target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
|
||||
self.assertTrue(
|
||||
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,44 +0,0 @@
|
||||
"""
|
||||
Unit tests for the monkey patch for expand mask to handle packed sequences
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.llama_expand_mask import _expand_mask
|
||||
|
||||
|
||||
class TestExpandMask(unittest.TestCase):
|
||||
"""
|
||||
Test class for attention mask expansion for packed sequences
|
||||
"""
|
||||
|
||||
def test_output(self):
|
||||
mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
|
||||
dtype = torch.float32
|
||||
expected_output = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
[-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
]
|
||||
],
|
||||
]
|
||||
)
|
||||
# Check that the output matches the expected output
|
||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
def test_increments_attention(self):
|
||||
def test_resets_attention(self):
|
||||
prompter = AlpacaPrompter("chat")
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
@@ -55,14 +55,10 @@ class TestPacking(unittest.TestCase):
|
||||
# first example doesn't have mask reset
|
||||
assert example["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][0] == 1
|
||||
assert example["position_ids"][0] == 0
|
||||
assert example["position_ids"][1] == 1
|
||||
|
||||
# but subsequent one does
|
||||
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][next_bos_index] == 2
|
||||
assert example["position_ids"][next_bos_index] == 0
|
||||
assert example["position_ids"][next_bos_index + 1] == 1
|
||||
assert example["attention_mask"][next_bos_index] == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -4,17 +4,13 @@ import logging
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import AutoTokenizer, LlamaTokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||
InstructionWSystemPromptTokenizingStrategy,
|
||||
SystemDataPrompter,
|
||||
)
|
||||
from axolotl.prompt_strategies.llama2_chat import (
|
||||
Llama2ChatPrompter,
|
||||
LLama2ChatTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
@@ -134,95 +130,9 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
||||
"output": "Hi! How can I help?",
|
||||
}
|
||||
example = strat.tokenize_prompt(sample)
|
||||
assert example["input_ids"][0:5] == [
|
||||
1,
|
||||
28962,
|
||||
1254,
|
||||
12665,
|
||||
29901,
|
||||
] # "<s>SYSTEM:"
|
||||
assert example["input_ids"][5:7] == [671, 20118] # " use cot"
|
||||
assert example["input_ids"][8] == 11889 # USER
|
||||
|
||||
|
||||
class Llama2ChatTokenizationTest(unittest.TestCase):
|
||||
"""
|
||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
# woraround because official Meta repos are not open
|
||||
|
||||
def test_llama2_chat_integration(self):
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
conversation = json.loads(data)
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.tokenized_llama2chat.json",
|
||||
encoding="utf-8",
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
tokenized_conversation = json.loads(data)
|
||||
prompter = Llama2ChatPrompter()
|
||||
strat = LLama2ChatTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
4096,
|
||||
)
|
||||
example = strat.tokenize_prompt(conversation)
|
||||
for fields in ["input_ids", "attention_mask", "labels"]:
|
||||
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
||||
self.assertEqual(example[fields], tokenized_conversation[fields])
|
||||
|
||||
def compare_with_transformers_integration(self):
|
||||
# this needs transformers >= v4.31.0
|
||||
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
|
||||
# from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT
|
||||
# broken as of 23/7/20
|
||||
# see https://github.com/huggingface/transformers/pull/24935
|
||||
# pylint: disable=C0103
|
||||
DEFAULT_SYSTEM_PROMPT = """\
|
||||
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
||||
|
||||
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
conversation = json.loads(data)
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.tokenized_llama2chat.json",
|
||||
encoding="utf-8",
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
tokenized_conversation = json.loads(data)
|
||||
|
||||
user_input = []
|
||||
answers = []
|
||||
for msg in conversation["conversations"]:
|
||||
if msg["from"] == "human":
|
||||
user_input.append(msg["value"])
|
||||
else:
|
||||
answers.append(msg["value"])
|
||||
hf_conf = Conversation(
|
||||
text=user_input[-1],
|
||||
past_user_inputs=[B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + user_input[0]]
|
||||
+ user_input[1:-1],
|
||||
generated_responses=answers,
|
||||
)
|
||||
# pylint: disable=W0212
|
||||
hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
|
||||
|
||||
self.assertEqual(
|
||||
hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
|
||||
)
|
||||
assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "<s>### System:"
|
||||
assert example["input_ids"][5:7] == [1509, 20118] # "use cot"
|
||||
assert example["input_ids"][9] == 11889 # USER
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
assert "use cot" in res
|
||||
assert res.startswith("SYSTEM:")
|
||||
assert res.startswith("### System:")
|
||||
assert "### Instruction:" not in res
|
||||
assert "### Input:" not in res
|
||||
assert "alpacas" in res
|
||||
|
||||
@@ -313,27 +313,3 @@ class ValidationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_packing(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"max_packed_sequence_len": 2048,
|
||||
}
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"max_packed_sequence_len will be deprecated in favor of sample_packing"
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"max_packed_sequence_len": 2048,
|
||||
"sample_packing": True,
|
||||
}
|
||||
)
|
||||
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
Reference in New Issue
Block a user