Compare commits
1 Commits
feat/space
...
sdpa-multi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a538be9c2 |
2
.github/FUNDING.yml
vendored
2
.github/FUNDING.yml
vendored
@@ -1,6 +1,6 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: [winglian, OpenAccess-AI-Collective] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: axolotl_ai # Replace with a single Ko-fi username
|
||||
|
||||
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -73,7 +73,7 @@ jobs:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.2
|
||||
pytorch: 2.1.1
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
26
README.md
26
README.md
@@ -37,9 +37,6 @@ Features:
|
||||
- [Inference](#inference)
|
||||
- [Merge LORA to Base](#merge-lora-to-base)
|
||||
- [Special Tokens](#special-tokens)
|
||||
- Advanced Topics
|
||||
- [Multipack](./docs/multipack.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||
- [RLHF & DPO](./docs/rlhf.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||
- [Common Errors](#common-errors-)
|
||||
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
||||
- [Debugging Axolotl](#debugging-axolotl)
|
||||
@@ -610,17 +607,6 @@ datasets:
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
field:
|
||||
|
||||
# A list of one or more datasets to eval the model with.
|
||||
# You can use either test_datasets, or val_set_size, but not both.
|
||||
test_datasets:
|
||||
- path: /workspace/data/eval.jsonl
|
||||
ds_type: json
|
||||
# You need to specify a split. For "json" datasets the default split is called "train".
|
||||
split: train
|
||||
type: completion
|
||||
data_files:
|
||||
- /workspace/data/eval.jsonl
|
||||
|
||||
# use RL training: dpo, ipo, kto_pair
|
||||
rl:
|
||||
|
||||
@@ -710,12 +696,6 @@ lora_modules_to_save:
|
||||
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
peft:
|
||||
# Configuration options for loftq initialization for LoRA
|
||||
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
|
||||
loftq_config:
|
||||
loftq_bits: # typically 4 bits
|
||||
|
||||
# ReLoRA configuration
|
||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # Number of steps per ReLoRA restart
|
||||
@@ -1155,11 +1135,9 @@ Having misalignment between your prompts during training and inference can cause
|
||||
|
||||
See [this debugging guide](docs/debugging.md) for tips on debugging Axolotl, along with an example configuration for debugging with VSCode.
|
||||
|
||||
## Need help? 🙋
|
||||
## Need help? 🙋♂️
|
||||
|
||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we our community members can help you.
|
||||
|
||||
Need dedicated support? Please contact us at [✉️wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org) for dedicated support options.
|
||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||
|
||||
## Badge ❤🏷️
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ EXPOSE 8888
|
||||
EXPOSE 22
|
||||
|
||||
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
||||
COPY scripts/motd /etc/motd
|
||||
|
||||
RUN pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
@@ -19,7 +18,6 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
||||
mkdir -p ~/.ssh && \
|
||||
chmod 700 ~/.ssh && \
|
||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
||||
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
|
||||
chmod +x /root/cloud-entrypoint.sh
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 239 KiB |
@@ -1,11 +1,4 @@
|
||||
# Multipack (Sample Packing)
|
||||
|
||||
## Visualization of Multipack with Flash Attention
|
||||
|
||||
Because Flash Attention simply drops the attention mask, we do not need to
|
||||
construct a 4d attention mask. We only need to concatenate the sequences into
|
||||
a single batch and let flash attention know where each new sequence begins.
|
||||
|
||||
# Multipack
|
||||
|
||||
4k context, bsz =4,
|
||||
each character represents 256 tokens
|
||||
@@ -56,18 +49,3 @@ w packing ( note it's the same effective number of tokens per step, but a true b
|
||||
E E E E F F F F F G G G H H H H
|
||||
I I I J J J J K K K K K L L L X ]]
|
||||
```
|
||||
|
||||
cu_seqlens:
|
||||
[[ 0, 11, 17, 24, 28, 36, 41 44, 48, 51, 55, 60, 64]]
|
||||
|
||||
|
||||
## Multipack without Flash Attention
|
||||
|
||||
Multipack can still be achieved without Flash attention, but with lower packing
|
||||
efficiency as we are not able to join multiple batches into a single batch due to
|
||||
context length limits without flash attention. We can use either Pytorch's Scaled
|
||||
Dot Product Attention implementation or native Pytorch attention implementation
|
||||
along with [4d attention masks](https://github.com/huggingface/transformers/pull/27539)
|
||||
to pack sequences together and avoid cross attention.
|
||||
|
||||
<img src="./images/4d-mask.png" alt="axolotl" width="800">
|
||||
|
||||
@@ -12,8 +12,8 @@ feedback. Various methods include, but not limited to:
|
||||
|
||||
### RLHF using Axolotl
|
||||
|
||||
>[!IMPORTANT]
|
||||
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
|
||||
[!IMPORTANT]
|
||||
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
|
||||
|
||||
The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
|
||||
@@ -43,7 +43,6 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install torch==\"2.1.2\"\n",
|
||||
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
|
||||
"!pip install flash-attn==\"2.5.0\"\n",
|
||||
"!pip install deepspeed==\"0.13.1\""
|
||||
|
||||
@@ -67,3 +67,6 @@ weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
peft:
|
||||
loftq_config:
|
||||
loftq_bits: 4
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -65,3 +65,6 @@ weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
|
||||
@@ -65,3 +65,6 @@ weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
|
||||
@@ -12,7 +12,6 @@ max_steps: 200
|
||||
pretraining_dataset:
|
||||
path: c4
|
||||
name: en
|
||||
type: pretrain
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./model-out
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
|
||||
peft==0.7.1
|
||||
transformers==4.37.0
|
||||
tokenizers==0.15.0
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate==0.26.1
|
||||
|
||||
17
scripts/motd
17
scripts/motd
@@ -1,17 +0,0 @@
|
||||
|
||||
dP dP dP
|
||||
88 88 88
|
||||
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
|
||||
88' `88 `8bd8' 88' `88 88 88' `88 88 88
|
||||
88. .88 .d88b. 88. .88 88 88. .88 88 88
|
||||
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
|
||||
|
||||
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
|
||||
|
||||
```
|
||||
cd /workspace
|
||||
rm -rf /workspace/axolotl
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
cd axolotl
|
||||
pip install --no-deps -e .
|
||||
```
|
||||
3
setup.py
3
setup.py
@@ -27,7 +27,6 @@ def parse_requirements():
|
||||
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
if torch_version.startswith("2.1."):
|
||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||
_install_requires.append("xformers>=0.0.23")
|
||||
@@ -51,7 +50,7 @@ setup(
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn==2.5.0",
|
||||
"flash-attn==2.3.3",
|
||||
],
|
||||
"fused-dense-lib": [
|
||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
|
||||
|
||||
@@ -6,9 +6,8 @@ from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import fire
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
@@ -28,7 +27,7 @@ LOG = logging.getLogger("axolotl.cli.train")
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser((TrainerCliArgs))
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
@@ -59,22 +59,6 @@ except ImportError:
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
|
||||
|
||||
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||
if isinstance(tag_names, str):
|
||||
tag_names = [tag_names]
|
||||
|
||||
if kwargs is not None:
|
||||
if "tags" not in kwargs:
|
||||
kwargs["tags"] = tag_names
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||
kwargs["tags"].extend(tag_names)
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||
tag_names.append(kwargs["tags"])
|
||||
kwargs["tags"] = tag_names
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
@@ -98,10 +82,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
@@ -126,10 +106,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
@@ -237,19 +213,11 @@ class AxolotlTrainer(Trainer):
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_train_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_train_batch_size * self.args.max_seq_length
|
||||
)
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
batch_size=batch_size,
|
||||
self.args.train_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
@@ -259,19 +227,11 @@ class AxolotlTrainer(Trainer):
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_eval_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
||||
)
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
batch_size=batch_size,
|
||||
self.args.per_device_eval_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
||||
lengths=get_dataset_lengths(eval_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
@@ -389,13 +349,30 @@ class AxolotlTrainer(Trainer):
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
|
||||
if isinstance(tag_names, str):
|
||||
tag_names = [tag_names]
|
||||
|
||||
if kwargs is not None:
|
||||
if "tags" not in kwargs:
|
||||
kwargs["tags"] = tag_names
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||
kwargs["tags"].extend(tag_names)
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||
tag_names.append(kwargs["tags"])
|
||||
kwargs["tags"] = tag_names
|
||||
|
||||
return kwargs
|
||||
|
||||
@wraps(Trainer.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
kwargs = self._sanitize_kwargs_for_tagging(
|
||||
tag_names=self.tag_names, kwargs=kwargs
|
||||
)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@@ -482,14 +459,10 @@ class ReLoRATrainer(AxolotlTrainer):
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
anneal_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
@@ -498,24 +471,6 @@ class ReLoRATrainer(AxolotlTrainer):
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "dpo"]
|
||||
|
||||
@wraps(DPOTrainer.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
@@ -763,7 +718,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
||||
training_arguments_kwargs["dataloader_drop_last"] = True
|
||||
|
||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||
if self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
@@ -850,7 +805,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.load_best_model_at_end is not False
|
||||
or self.cfg.early_stopping_patience
|
||||
)
|
||||
and not self.cfg.test_datasets
|
||||
and self.cfg.val_set_size > 0
|
||||
and self.cfg.save_steps
|
||||
and self.cfg.eval_steps
|
||||
@@ -888,9 +842,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["sample_packing"] = (
|
||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||
)
|
||||
training_arguments_kwargs["multipack_real_batches"] = (
|
||||
self.cfg.flash_attention is not True
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = (
|
||||
self.cfg.sample_packing
|
||||
if self.cfg.eval_sample_packing is not False
|
||||
@@ -901,7 +852,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
] = self.cfg.micro_batch_size
|
||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
||||
training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
|
||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||
training_arguments_kwargs
|
||||
)
|
||||
@@ -996,11 +946,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if use_batch_sampler_collator:
|
||||
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif (
|
||||
self.cfg.model_config_type in ["llama"]
|
||||
and self.cfg.flash_attention is not True
|
||||
):
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
@@ -1096,21 +1041,13 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
||||
"use_reentrant": False
|
||||
}
|
||||
|
||||
# set save_strategy and save_steps
|
||||
if self.cfg.save_steps:
|
||||
training_args_kwargs["save_strategy"] = "steps"
|
||||
training_args_kwargs["save_steps"] = self.cfg.save_steps
|
||||
elif self.cfg.save_strategy:
|
||||
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
max_steps=self.cfg.max_steps or total_num_steps,
|
||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||
learning_rate=self.cfg.learning_rate,
|
||||
save_strategy="steps",
|
||||
save_steps=self.cfg.save_steps,
|
||||
output_dir=self.cfg.output_dir,
|
||||
warmup_steps=self.cfg.warmup_steps,
|
||||
logging_first_step=True,
|
||||
@@ -1139,7 +1076,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs[
|
||||
"precompute_ref_log_probs"
|
||||
] = self.cfg.precompute_ref_log_probs
|
||||
dpo_trainer = AxolotlDPOTrainer(
|
||||
dpo_trainer = DPOTrainer(
|
||||
self.model,
|
||||
self.model_ref,
|
||||
args=training_args,
|
||||
|
||||
@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: Dataset,
|
||||
dataset: IterableDataset,
|
||||
process_count: Optional[int] = None,
|
||||
keep_in_memory: Optional[bool] = False,
|
||||
**kwargs,
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
|
||||
# pylint: disable=protected-access
|
||||
|
||||
import torch
|
||||
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
|
||||
from torch.utils.data._utils.worker import _worker_loop
|
||||
|
||||
|
||||
class _MapDatasetFetcher(_BaseDatasetFetcher):
|
||||
def fetch(self, possibly_batched_index):
|
||||
if isinstance(possibly_batched_index[0], list):
|
||||
data = [None for i in possibly_batched_index]
|
||||
for i, possibly_batched_index_ in enumerate(possibly_batched_index):
|
||||
if self.auto_collation:
|
||||
if (
|
||||
hasattr(self.dataset, "__getitems__")
|
||||
and self.dataset.__getitems__
|
||||
):
|
||||
data[i] = self.dataset.__getitems__(possibly_batched_index_)
|
||||
else:
|
||||
data[i] = [self.dataset[idx] for idx in possibly_batched_index_]
|
||||
else:
|
||||
data[i] = self.dataset[possibly_batched_index_]
|
||||
else:
|
||||
if self.auto_collation:
|
||||
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
|
||||
data = self.dataset.__getitems__(possibly_batched_index)
|
||||
else:
|
||||
data = [self.dataset[idx] for idx in possibly_batched_index]
|
||||
else:
|
||||
data = self.dataset[possibly_batched_index]
|
||||
return self.collate_fn(data)
|
||||
|
||||
|
||||
def patch_fetchers():
|
||||
torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
|
||||
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
|
||||
|
||||
|
||||
def patched_worker_loop(*args, **kwargs):
|
||||
patch_fetchers()
|
||||
return _worker_loop(*args, **kwargs)
|
||||
|
||||
|
||||
torch.utils.data._utils.worker._worker_loop = patched_worker_loop
|
||||
patch_fetchers()
|
||||
142
src/axolotl/monkeypatch/llama_attn_hijack_sdp.py
Normal file
142
src/axolotl/monkeypatch/llama_attn_hijack_sdp.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers.models.llama.modeling_llama
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
|
||||
def hijack_llama_sdp_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
sdp_attention_forward
|
||||
)
|
||||
|
||||
|
||||
def sdp_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, "pretraining_tp"):
|
||||
self.pretraining_tp = 1
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
key_value_slicing = (
|
||||
self.num_key_value_heads * self.head_dim
|
||||
) // self.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [
|
||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [
|
||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [
|
||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if output_attentions:
|
||||
warnings.warn(
|
||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||
)
|
||||
|
||||
#
|
||||
# sdp-attn start
|
||||
#
|
||||
|
||||
with torch.backends.cuda.sdp_kernel():
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
#
|
||||
# sdp-attn end
|
||||
#
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(
|
||||
self.hidden_size // self.pretraining_tp, dim=1
|
||||
)
|
||||
attn_output = sum(
|
||||
F.linear(attn_output[i], o_proj_slices[i])
|
||||
for i in range(self.pretraining_tp)
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
@@ -5,11 +5,38 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.utils import mask_2d_to_4d
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
inverted_mask = 1.0 - masked_zero_one_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
from axolotl.monkeypatch.utils import (
|
||||
patched_prepare_4d_causal_attention_mask,
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
|
||||
|
||||
def hijack_llama_prepare_4d_mask():
|
||||
import transformers.modeling_attn_mask_utils
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
@@ -94,7 +94,7 @@ def _prepare_decoder_attention_mask(
|
||||
sliding_window,
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
if attention_mask is None or sliding_window is None:
|
||||
if attention_mask is None:
|
||||
return attention_mask
|
||||
|
||||
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
||||
@@ -151,7 +151,7 @@ def flashattn_forward(
|
||||
)
|
||||
|
||||
use_sliding_windows = (
|
||||
getattr(self.config, "sliding_window") is not None
|
||||
hasattr(self.config, "sliding_window") is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
)
|
||||
|
||||
|
||||
@@ -4,16 +4,14 @@ import json
|
||||
import logging
|
||||
import os.path
|
||||
import shutil
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Sequence, Union
|
||||
from typing import Dict, List, Sequence
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import peft
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from torch.distributed.optim import ZeroRedundancyOptimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from transformers import (
|
||||
@@ -25,50 +23,23 @@ from transformers import (
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import barrier, is_main_process
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
LOG = logging.getLogger("axolotl.relora")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def magnitude_pruning_(tensor, prune_ratio):
|
||||
tensor_magnitude = torch.abs(tensor)
|
||||
threshold = torch.quantile(
|
||||
tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
|
||||
).to(dtype=tensor.dtype)
|
||||
def reset_optimizer(optimizer: torch.optim.Optimizer):
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
param_state = optimizer.state[param]
|
||||
for key in param_state:
|
||||
if "qmap" in key:
|
||||
continue
|
||||
|
||||
mask = tensor_magnitude > threshold
|
||||
tensor.mul_(mask.to(dtype=tensor.dtype))
|
||||
|
||||
|
||||
def reset_optimizer(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
*,
|
||||
reset_params: list[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: list[str],
|
||||
):
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
|
||||
n_zeros = 0
|
||||
n_total = 0
|
||||
|
||||
optimizer_state = optimizer.state
|
||||
if isinstance(optimizer, ZeroRedundancyOptimizer):
|
||||
optimizer_state = optimizer.optim.state
|
||||
|
||||
for param in reset_params:
|
||||
param_state = optimizer_state[param]
|
||||
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
|
||||
continue
|
||||
for key in optimizer_state_keys:
|
||||
pruning_fn(
|
||||
param_state[key]
|
||||
) # pruning fn has to be inplace to keep the same keys in the dict
|
||||
n_total += param_state[key].numel()
|
||||
n_zeros += torch.sum(param_state[key] == 0).item()
|
||||
|
||||
_zeroed = n_zeros / (1e-7 + n_total) * 100
|
||||
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
|
||||
LOG.info(f"absolute n of optimizer states zeroed: {n_zeros}")
|
||||
if key == "step" and isinstance(param_state[key], int):
|
||||
param_state[key] = 0
|
||||
else:
|
||||
param_state[key] = torch.zeros_like(param_state[key])
|
||||
|
||||
|
||||
class ReLoRACallback(TrainerCallback):
|
||||
@@ -126,25 +97,6 @@ class ReLoRACallback(TrainerCallback):
|
||||
"relora",
|
||||
)
|
||||
|
||||
if "adam" in args.optim.lower():
|
||||
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
|
||||
else:
|
||||
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
||||
|
||||
lora_params = [
|
||||
n
|
||||
for n, p in model.named_parameters()
|
||||
if p.requires_grad and "lora_" in n
|
||||
]
|
||||
|
||||
model.save_pretrained(
|
||||
os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
"adapter",
|
||||
),
|
||||
safe_serialization=True,
|
||||
)
|
||||
with torch.no_grad():
|
||||
merge_and_save(
|
||||
model,
|
||||
@@ -155,11 +107,7 @@ class ReLoRACallback(TrainerCallback):
|
||||
actually_save=is_main_process(),
|
||||
cpu_offload=self.cpu_offload,
|
||||
)
|
||||
reset_optimizer(
|
||||
optimizer,
|
||||
reset_params=lora_params,
|
||||
optimizer_state_keys=optimizer_state_keys,
|
||||
)
|
||||
reset_optimizer(optimizer)
|
||||
|
||||
if self.quantized:
|
||||
self.last_full_model = checkpoint_folder
|
||||
@@ -249,13 +197,11 @@ class ReLoRAScheduler(LRScheduler):
|
||||
inner_schedule: LRScheduler,
|
||||
relora_steps: int,
|
||||
warmup_steps: int,
|
||||
anneal_steps: int = 1,
|
||||
min_lr_scale: float = 0.001,
|
||||
) -> None:
|
||||
self.inner_schedule = inner_schedule
|
||||
self.relora_steps = relora_steps
|
||||
self.warmup_steps = warmup_steps
|
||||
self.anneal_steps = anneal_steps
|
||||
self.min_lr_scale = min_lr_scale
|
||||
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
||||
|
||||
@@ -264,20 +210,10 @@ class ReLoRAScheduler(LRScheduler):
|
||||
|
||||
original = self.inner_schedule.get_lr()
|
||||
step = self.last_epoch
|
||||
|
||||
if step < self.relora_steps:
|
||||
scale = 1
|
||||
else:
|
||||
per_relora_progress = step % self.relora_steps
|
||||
if per_relora_progress < self.warmup_steps:
|
||||
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
|
||||
elif per_relora_progress > (self.relora_steps - self.anneal_steps):
|
||||
cycle_t = min(
|
||||
1.0,
|
||||
(self.relora_steps - per_relora_progress) / self.anneal_steps,
|
||||
)
|
||||
else:
|
||||
cycle_t = 1
|
||||
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
|
||||
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
||||
|
||||
if isinstance(original, Sequence):
|
||||
@@ -302,11 +238,7 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
|
||||
|
||||
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
|
||||
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
|
||||
adapter: Union[List[str], str] = layer.active_adapter
|
||||
if isinstance(adapter, list):
|
||||
if len(adapter) > 1:
|
||||
raise ValueError("unhandled relora for multiple adapters")
|
||||
adapter = adapter[0]
|
||||
adapter = layer.active_adapter
|
||||
return (
|
||||
peft.utils.transpose(
|
||||
layer.lora_B[adapter].weight.detach().to(device)
|
||||
@@ -316,7 +248,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
|
||||
* layer.scaling[adapter]
|
||||
)
|
||||
|
||||
raise ValueError("unhandled lora layer type")
|
||||
return layer.get_delta_weight().to(device)
|
||||
|
||||
|
||||
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
|
||||
@@ -341,9 +273,9 @@ def update_weights(
|
||||
):
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name, True)
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name, True)
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
|
||||
if isinstance(target, peft.tuners.lora.Linear4bit):
|
||||
# This could be faster, but the quantization of Linear4bit weights occurs
|
||||
@@ -354,9 +286,7 @@ def update_weights(
|
||||
target.weight.data = new_weight.cpu()
|
||||
target.to(device)
|
||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||
target.weight.data = (
|
||||
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
|
||||
)
|
||||
target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
|
||||
else:
|
||||
target.weight.data = new_weight.to(device)
|
||||
|
||||
@@ -374,17 +304,14 @@ def merge_and_save(
|
||||
|
||||
if not quantized:
|
||||
for module_name, target in modules.items():
|
||||
active_adapter = target.active_adapter
|
||||
if isinstance(active_adapter, list):
|
||||
active_adapter = active_adapter[0]
|
||||
update = target.get_delta_weight(active_adapter).detach()
|
||||
update = target.get_delta_weight(target.active_adapter).detach()
|
||||
target.weight.data += update
|
||||
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name, True)
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name, True)
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
return
|
||||
|
||||
os.makedirs(model_dst, exist_ok=True)
|
||||
@@ -436,7 +363,6 @@ def merge_and_save(
|
||||
LOG.info(f"saving tensors to {shard_fn}")
|
||||
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
|
||||
|
||||
barrier()
|
||||
del in_tensors
|
||||
del out_tensors
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -1,15 +1,8 @@
|
||||
"""
|
||||
Shared utils for the monkeypatches
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -96,6 +89,7 @@ def get_cu_seqlens(attn_mask):
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||
if len(position_ids.shape) == 1:
|
||||
@@ -141,18 +135,7 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
results.append(cu_seqlens)
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
# Find the maximum value across all tensors
|
||||
max_value = max(t.max() for t in results)
|
||||
|
||||
# Find the length of the longest tensor
|
||||
max_length = max(t.size(0) for t in results)
|
||||
|
||||
# Pad each tensor to the same length and collect them in a list
|
||||
padded_results = [
|
||||
F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results
|
||||
]
|
||||
|
||||
return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
def set_module_name(model, name, value):
|
||||
@@ -166,62 +149,3 @@ def set_module_name(model, name, value):
|
||||
child_name = name
|
||||
|
||||
setattr(parent, child_name, value)
|
||||
|
||||
|
||||
def mask_2d_to_4d(
|
||||
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
|
||||
return masked_zero_one_mask
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
strategy = InstructShareGPTPromptTokenizingStrategy(
|
||||
# pylint: disable=duplicate-code
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
return strategy
|
||||
|
||||
|
||||
class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return [
|
||||
{"from": "human", "value": prompt["instruction"]},
|
||||
{"from": "gpt", "value": prompt["output"]},
|
||||
]
|
||||
@@ -1,58 +0,0 @@
|
||||
"""pretraining prompt strategies"""
|
||||
from typing import Generator
|
||||
|
||||
from transformers import BatchEncoding
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
|
||||
class PretrainTokenizer:
|
||||
"""basic tokenization class for pretraining"""
|
||||
|
||||
def build_prompt(self, prompt) -> Generator[str, None, None]:
|
||||
yield prompt
|
||||
|
||||
|
||||
class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
||||
"""handles tokenization for pretraining with strides"""
|
||||
|
||||
@property
|
||||
def supports_batched(self):
|
||||
return True
|
||||
|
||||
def __init__(self, *args, max_length=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if max_length:
|
||||
self.max_length = max_length
|
||||
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
) -> BatchEncoding:
|
||||
res = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.max_length - 1,
|
||||
add_special_tokens=True,
|
||||
return_overflowing_tokens=True,
|
||||
stride=256,
|
||||
)
|
||||
res["input_ids"] = [
|
||||
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
|
||||
]
|
||||
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
|
||||
|
||||
return res
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
return self._tokenize(prompt["text"])
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
strat = PretrainTokenizationStrategy(
|
||||
PretrainTokenizer(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
max_length=cfg.sequence_len * 64,
|
||||
)
|
||||
return strat
|
||||
@@ -11,6 +11,7 @@ import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from peft import PeftModel
|
||||
from pkg_resources import get_distribution # type: ignore
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
@@ -23,11 +24,6 @@ from axolotl.utils.freeze import freeze_parameters_except
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
try:
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
except ImportError:
|
||||
BetterTransformer = None
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
@@ -61,21 +57,6 @@ def train(
|
||||
eval_dataset = dataset_meta.eval_dataset
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
]
|
||||
if len(possible_checkpoints) > 0:
|
||||
sorted_paths = sorted(
|
||||
possible_checkpoints,
|
||||
key=lambda path: int(path.split("-")[-1]),
|
||||
)
|
||||
cfg.resume_from_checkpoint = sorted_paths[-1]
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
)
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
|
||||
# Load the model and tokenizer
|
||||
msg = "loading model"
|
||||
if cfg.adapter:
|
||||
@@ -98,6 +79,21 @@ def train(
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
]
|
||||
if len(possible_checkpoints) > 0:
|
||||
sorted_paths = sorted(
|
||||
possible_checkpoints,
|
||||
key=lambda path: int(path.split("-")[-1]),
|
||||
)
|
||||
cfg.resume_from_checkpoint = sorted_paths[-1]
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
)
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
||||
|
||||
@@ -128,7 +124,7 @@ def train(
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
sys.exit(0)
|
||||
@@ -153,10 +149,7 @@ def train(
|
||||
pretrain_hooks(cfg, trainer)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
enable_flash=True,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=True,
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
):
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
@@ -202,7 +195,7 @@ def train(
|
||||
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
||||
)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
@@ -19,7 +19,6 @@ def chat_templates(user_choice: str):
|
||||
"""
|
||||
|
||||
templates = {
|
||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
}
|
||||
|
||||
@@ -132,26 +132,24 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if not isinstance(features[0], list):
|
||||
features = [features]
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for i, item in enumerate(features_)
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for item in features
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -161,26 +159,24 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if not isinstance(features[0], list):
|
||||
features = [features]
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(i + 1) * np.array(item[feature])
|
||||
for i, item in enumerate(features_)
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(i + 1) * np.array(item[feature])
|
||||
for i, item in enumerate(features)
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -202,20 +202,6 @@ def validate_config(cfg):
|
||||
raise ValueError(
|
||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||
)
|
||||
if (
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
not (cfg.bf16 or cfg.bfloat16)
|
||||
and (cfg.fp16 or cfg.float16)
|
||||
and not cfg.adapter
|
||||
and not cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
LOG.warning(
|
||||
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
|
||||
)
|
||||
# ValueError: Attempting to unscale FP16 gradients.
|
||||
# OR
|
||||
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
|
||||
if cfg.max_packed_sequence_len:
|
||||
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||
|
||||
@@ -246,6 +232,9 @@ def validate_config(cfg):
|
||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||
)
|
||||
|
||||
if cfg.load_4bit:
|
||||
raise ValueError("cfg.load_4bit parameter has been deprecated")
|
||||
|
||||
if cfg.adapter == "qlora":
|
||||
if cfg.merge_lora:
|
||||
# can't merge qlora if loaded in 8bit or 4bit
|
||||
@@ -271,8 +260,7 @@ def validate_config(cfg):
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with QLoRA")
|
||||
|
||||
loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||
@@ -322,7 +310,7 @@ def validate_config(cfg):
|
||||
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
||||
if cfg.fp16 or cfg.bf16:
|
||||
raise ValueError("AMP is not supported with BetterTransformer")
|
||||
if cfg.float16 is not True and cfg.bfloat16 is not True:
|
||||
if cfg.float16 is not True and cfg.bloat16 is not True:
|
||||
LOG.warning(
|
||||
"You should probably set bfloat16 or float16 to true to "
|
||||
"load the model in float16 for BetterTransformers"
|
||||
@@ -352,11 +340,6 @@ def validate_config(cfg):
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
|
||||
LOG.warning(
|
||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.model_revision:
|
||||
raise ValueError(
|
||||
"model_revision is not supported for GPTQ models. "
|
||||
@@ -364,24 +347,17 @@ def validate_config(cfg):
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
)
|
||||
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
# raise ValueError(
|
||||
# "sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
# )
|
||||
if cfg.sample_packing and cfg.sdp_attention:
|
||||
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.xformers_attention:
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
|
||||
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
||||
LOG.warning(
|
||||
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
||||
"This may work on H100s."
|
||||
)
|
||||
|
||||
if cfg.early_stopping_patience:
|
||||
if not cfg.save_steps or not cfg.eval_steps:
|
||||
raise ValueError(
|
||||
@@ -447,11 +423,7 @@ def validate_config(cfg):
|
||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.val_set_size == 0
|
||||
and (cfg.eval_steps or cfg.evaluation_strategy)
|
||||
and not cfg.test_datasets
|
||||
):
|
||||
if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
|
||||
raise ValueError(
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ import hashlib
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@@ -88,21 +88,12 @@ def prepare_dataset(cfg, tokenizer):
|
||||
path = cfg.pretraining_dataset[0]["path"]
|
||||
name = cfg.pretraining_dataset[0]["name"]
|
||||
|
||||
ds_wrapper_partial = functools.partial(
|
||||
get_dataset_wrapper,
|
||||
cfg.pretraining_dataset[0],
|
||||
train_dataset = load_pretraining_dataset(
|
||||
path,
|
||||
tokenizer,
|
||||
cfg,
|
||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||
)
|
||||
|
||||
train_dataset = wrap_pretraining_dataset(
|
||||
load_dataset(path, streaming=True, split="train", name=name),
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_partial,
|
||||
name=name,
|
||||
max_tokens=cfg.sequence_len,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seed=cfg.seed or 42,
|
||||
)
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
@@ -149,7 +140,7 @@ def load_tokenized_prepared_datasets(
|
||||
+ "|".join(
|
||||
sorted(
|
||||
[
|
||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
|
||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
|
||||
for d in cfg_datasets
|
||||
]
|
||||
)
|
||||
@@ -392,9 +383,9 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||
config_dataset=config_dataset,
|
||||
dataset=ds,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
dataset=ds,
|
||||
d_base_type=d_base_type,
|
||||
d_prompt_style=d_prompt_style,
|
||||
)
|
||||
@@ -449,7 +440,7 @@ def load_prepare_datasets(
|
||||
split="train",
|
||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
||||
dataset, prompters = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, default_dataset_prepared_path, split=split
|
||||
tokenizer, cfg, default_dataset_prepared_path
|
||||
)
|
||||
|
||||
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
||||
@@ -505,12 +496,7 @@ def load_prepare_datasets(
|
||||
|
||||
|
||||
def get_dataset_wrapper(
|
||||
config_dataset,
|
||||
tokenizer,
|
||||
cfg,
|
||||
d_base_type,
|
||||
dataset,
|
||||
d_prompt_style=None,
|
||||
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
|
||||
):
|
||||
dataset_wrapper = None
|
||||
dataset_prompter = None
|
||||
@@ -521,8 +507,7 @@ def get_dataset_wrapper(
|
||||
}
|
||||
|
||||
if (
|
||||
isinstance(dataset, Dataset)
|
||||
and "input_ids" in dataset.features
|
||||
"input_ids" in dataset.features
|
||||
and "attention_mask" in dataset.features
|
||||
and "labels" in dataset.features
|
||||
):
|
||||
@@ -780,67 +765,76 @@ def encode_pretraining(
|
||||
return ret
|
||||
|
||||
|
||||
def wrap_pretraining_dataset(
|
||||
dataset,
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_fn,
|
||||
max_tokens=2048,
|
||||
batch_size=1,
|
||||
seed=42,
|
||||
buffer_size=10_000,
|
||||
):
|
||||
def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
|
||||
if cfg.sample_packing:
|
||||
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
pad_to_multiple_of=max_tokens * batch_size,
|
||||
pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
|
||||
)
|
||||
encode = functools.partial(
|
||||
encode_packed_pretraining,
|
||||
tokenizer,
|
||||
collate_fn,
|
||||
ds_wrapper_fn,
|
||||
max_seq_length=max_tokens,
|
||||
batch_size=batch_size,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
)
|
||||
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
||||
cfg.micro_batch_size = 1
|
||||
else:
|
||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
||||
dataset = load_dataset(path, streaming=True, split="train", name=name)
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
||||
dataset = dataset.map(
|
||||
encode,
|
||||
batched=True,
|
||||
batch_size=buffer_size,
|
||||
# input_columns="text",
|
||||
batch_size=10_000,
|
||||
input_columns="text",
|
||||
# remove all the existing columns after mapping since they end up having
|
||||
# a different length than the encoded/tokenized column
|
||||
remove_columns=dataset.features.keys(),
|
||||
desc="Encoding Pretraining",
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
def encode_packed_pretraining(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
collate_fn,
|
||||
ds_wrapper: Callable,
|
||||
examples: Dict[str, List],
|
||||
examples: List[str],
|
||||
max_seq_length: int = 2048,
|
||||
batch_size: int = 4,
|
||||
) -> Dict[str, List]:
|
||||
# pylint: disable=duplicate-code
|
||||
# tokenize all the examples
|
||||
# rows get split with stride (overlap)
|
||||
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
|
||||
res = tokenizer(
|
||||
examples,
|
||||
truncation=True,
|
||||
max_length=max_seq_length - 1,
|
||||
add_special_tokens=True,
|
||||
return_overflowing_tokens=True,
|
||||
stride=256,
|
||||
)
|
||||
|
||||
input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
|
||||
attention_mask = [seq + [1] for seq in res["attention_mask"]]
|
||||
|
||||
tokenized_examples = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
train_dataset = Dataset.from_dict(tokenized_examples)
|
||||
train_dataset = process_pretraining_datasets_for_packing(
|
||||
train_dataset, max_seq_length
|
||||
)
|
||||
|
||||
sampler = MultipackBatchSampler(
|
||||
RandomSampler(train_dataset),
|
||||
batch_size=1,
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=batch_size * max_seq_length,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
@@ -848,23 +842,15 @@ def encode_packed_pretraining(
|
||||
|
||||
chunked_data = defaultdict(list)
|
||||
|
||||
for batch in sampler:
|
||||
for data in batch:
|
||||
features = train_dataset[data]
|
||||
if "num_truncated_tokens" in features:
|
||||
del features["num_truncated_tokens"]
|
||||
if "num_truncated_tokens" in features:
|
||||
del features["num_truncated_tokens"]
|
||||
if "overflow_to_sample_mapping" in features:
|
||||
del features["overflow_to_sample_mapping"]
|
||||
if "labels" not in features:
|
||||
features["labels"] = features["input_ids"].copy()
|
||||
collated_features = collate_fn(features)
|
||||
for data in sampler:
|
||||
features = train_dataset[data]
|
||||
features["labels"] = features["input_ids"].copy()
|
||||
collated_features = collate_fn(features)
|
||||
|
||||
for feature in features.keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
||||
for feature in features.keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
||||
|
||||
return chunked_data
|
||||
|
||||
|
||||
@@ -8,13 +8,8 @@ import addict
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from peft import (
|
||||
LoftQConfig,
|
||||
PeftConfig,
|
||||
PeftModel,
|
||||
PeftModelForCausalLM,
|
||||
prepare_model_for_kbit_training,
|
||||
)
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from peft import PeftConfig, prepare_model_for_kbit_training
|
||||
from peft.tuners.lora import QuantLinear
|
||||
from transformers import ( # noqa: F401
|
||||
AddedToken,
|
||||
@@ -166,20 +161,15 @@ def load_tokenizer(cfg):
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||
|
||||
additional_special_tokens = None
|
||||
if cfg.special_tokens:
|
||||
special_tokens = cfg.special_tokens.to_dict()
|
||||
additional_special_tokens = special_tokens.pop(
|
||||
"additional_special_tokens", None
|
||||
)
|
||||
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
|
||||
for k, val in special_tokens.items():
|
||||
for k, val in cfg.special_tokens.items():
|
||||
# check if new special token is not already in tokenizer and
|
||||
# is adapter training to make sure lora_modules_to_save is set
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (
|
||||
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
||||
and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
|
||||
and (len(tokenizer.encode(val)) > 1)
|
||||
and cfg.adapter
|
||||
and (
|
||||
not cfg.lora_modules_to_save
|
||||
@@ -223,21 +213,6 @@ def load_tokenizer(cfg):
|
||||
]
|
||||
)
|
||||
|
||||
# Additional special tokens are a List, and need to be treated differently than regular special
|
||||
# tokens. We add them after we have called `add_tokens` in case these additional special tokens
|
||||
# are new tokens.
|
||||
#
|
||||
# Usage:
|
||||
#
|
||||
# ```py
|
||||
# special_tokens:
|
||||
# additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
|
||||
# ```
|
||||
if additional_special_tokens is not None:
|
||||
tokenizer.add_special_tokens(
|
||||
{"additional_special_tokens": additional_special_tokens}
|
||||
)
|
||||
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
@@ -329,13 +304,13 @@ def load_model(
|
||||
|
||||
LOG.info("patching with xformers attention")
|
||||
hijack_llama_attention()
|
||||
elif cfg.sample_packing:
|
||||
from axolotl.monkeypatch.llama_patch_multipack import (
|
||||
hijack_llama_prepare_4d_mask,
|
||||
elif cfg.sdp_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_sdp import (
|
||||
hijack_llama_sdp_attention,
|
||||
)
|
||||
|
||||
LOG.info("patching llama _prepare_4d_causal_attention_mask*")
|
||||
hijack_llama_prepare_4d_mask()
|
||||
LOG.info("patching with sdp attention")
|
||||
hijack_llama_sdp_attention()
|
||||
elif cfg.s2_attention:
|
||||
raise NotImplementedError(
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
@@ -478,18 +453,6 @@ def load_model(
|
||||
**bnb_config,
|
||||
)
|
||||
|
||||
if cfg.load_in_8bit and cfg.adapter is not None:
|
||||
model_kwargs["load_in_8bit"] = True
|
||||
if cfg.load_in_4bit and cfg.adapter is not None:
|
||||
model_kwargs["load_in_4bit"] = True
|
||||
|
||||
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
||||
if "quantization_config" in model_kwargs or cfg.gptq:
|
||||
if "load_in_8bit" in model_kwargs:
|
||||
del model_kwargs["load_in_8bit"]
|
||||
if "load_in_4bit" in model_kwargs:
|
||||
del model_kwargs["load_in_4bit"]
|
||||
|
||||
# sample packing uses custom FA2 patch
|
||||
if cfg.flash_attention:
|
||||
if not cfg.sample_packing:
|
||||
@@ -511,12 +474,6 @@ def load_model(
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"eager"
|
||||
)
|
||||
elif cfg.sdp_attention:
|
||||
model_kwargs["attn_implementation"] = "sdpa"
|
||||
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
|
||||
elif cfg.eager_attention:
|
||||
model_kwargs["attn_implementation"] = "eager"
|
||||
model_config._attn_implementation = "eager" # pylint: disable=protected-access
|
||||
|
||||
try:
|
||||
if (
|
||||
@@ -529,6 +486,8 @@ def load_model(
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@@ -596,6 +555,8 @@ def load_model(
|
||||
model = getattr(transformers, model_type).from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -627,6 +588,8 @@ def load_model(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -634,9 +597,6 @@ def load_model(
|
||||
LOG.exception(err)
|
||||
raise err
|
||||
|
||||
if isinstance(model, (PeftModel, PeftModelForCausalLM)):
|
||||
model = model.merge_and_unload()
|
||||
|
||||
embeddings_len = (
|
||||
math.ceil(len(tokenizer) / 32) * 32
|
||||
if cfg.resize_token_embeddings_to_32x
|
||||
@@ -685,7 +645,7 @@ def load_model(
|
||||
if not cfg.fsdp:
|
||||
# FSDP doesn't like mixed Float and BFloat16
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name or name.endswith(".gate"):
|
||||
if any(m in name for m in ["norm", "gate"]):
|
||||
module.to(torch.float32)
|
||||
if model_config.model_type == "btlm":
|
||||
# don't upcast lm_head for btlm
|
||||
@@ -698,9 +658,7 @@ def load_model(
|
||||
skip_prepare_model_for_kbit_training = False
|
||||
|
||||
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
from deepspeed.utils import ( # pylint: disable=no-name-in-module
|
||||
set_z3_leaf_modules,
|
||||
)
|
||||
from deepspeed.utils import set_z3_leaf_modules
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
|
||||
@@ -709,17 +667,13 @@ def load_model(
|
||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
||||
if cfg.adapter == "lora" and loftq_bits:
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if cfg.adapter in ["lora", "qlora"]:
|
||||
if (cfg.adapter == "lora" and load_in_8bit) or (
|
||||
cfg.adapter == "qlora" and cfg.load_in_4bit
|
||||
):
|
||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||
if cfg.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
if (
|
||||
cfg.load_in_8bit or cfg.load_in_4bit
|
||||
) and not skip_prepare_model_for_kbit_training:
|
||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||
if not skip_prepare_model_for_kbit_training:
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
@@ -746,7 +700,6 @@ def load_model(
|
||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||
|
||||
if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
|
||||
# TODO revaldate this conditional
|
||||
model.to(f"cuda:{cfg.local_rank}")
|
||||
|
||||
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
@@ -763,8 +716,6 @@ def load_model(
|
||||
model.config.use_cache = False
|
||||
|
||||
if cfg.flash_optimum:
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
|
||||
model = BetterTransformer.transform(model)
|
||||
|
||||
if cfg.adapter is not None:
|
||||
@@ -791,7 +742,7 @@ def load_adapter(model, cfg, adapter, inference=False):
|
||||
|
||||
def load_llama_adapter(model, cfg):
|
||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
from peft import AdaptionPromptConfig, get_peft_model
|
||||
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
|
||||
|
||||
peft_config = AdaptionPromptConfig(
|
||||
adapter_layers=cfg.peft_adapter.layers, # layers (L)
|
||||
@@ -800,7 +751,7 @@ def load_llama_adapter(model, cfg):
|
||||
)
|
||||
|
||||
if cfg.lora_model_dir:
|
||||
LOG.debug("Loading pretrained PEFT - llama_adapter")
|
||||
LOG.debug("Loading pretained PEFT - llama_adapter")
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
cfg.lora_model_dir,
|
||||
@@ -837,7 +788,7 @@ def find_all_linear_names(model):
|
||||
def load_lora(model, cfg, inference=False, config_only=False):
|
||||
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
|
||||
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
|
||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||
|
||||
@@ -846,12 +797,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
LOG.info(f"found linear modules: {repr(linear_names)}")
|
||||
lora_target_modules = list(set(lora_target_modules + linear_names))
|
||||
|
||||
lora_config_kwargs = {}
|
||||
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
||||
if loftq_bits:
|
||||
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
|
||||
lora_config_kwargs["init_lora_weights"] = "loftq"
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=cfg.lora_r,
|
||||
lora_alpha=cfg.lora_alpha,
|
||||
@@ -862,14 +807,13 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
**lora_config_kwargs,
|
||||
)
|
||||
|
||||
if config_only:
|
||||
return None, lora_config
|
||||
|
||||
if cfg.lora_model_dir:
|
||||
LOG.debug("Loading pretrained PEFT - LoRA")
|
||||
LOG.debug("Loading pretained PEFT - LoRA")
|
||||
model_kwargs: Any = {}
|
||||
if cfg.lora_on_cpu:
|
||||
model_kwargs["max_memory"] = {"cpu": "256GiB"}
|
||||
|
||||
@@ -117,7 +117,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.batch_size = batch_size
|
||||
self.batch_size = None
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
@@ -147,13 +147,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
n=1,
|
||||
)
|
||||
|
||||
batches = [
|
||||
[
|
||||
[indices[b_idx] for b_idx in batch]
|
||||
for batch in batches[i : i + self.batch_size]
|
||||
]
|
||||
for i in range(0, len(batches), self.batch_size)
|
||||
]
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
@@ -195,7 +189,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
// (self.batch_max_len * self.batch_size)
|
||||
// self.batch_max_len
|
||||
)
|
||||
- 1
|
||||
),
|
||||
|
||||
@@ -237,17 +237,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
main_process_only=True,
|
||||
)
|
||||
else:
|
||||
if cfg.flash_attention:
|
||||
batch_size = 1
|
||||
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
|
||||
else:
|
||||
batch_size = cfg.micro_batch_size
|
||||
batch_max_len = cfg.sequence_len
|
||||
sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
batch_size=batch_size,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
)
|
||||
|
||||
@@ -255,7 +249,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
train_dataset.remove_columns(["length"]),
|
||||
batch_sampler=sampler,
|
||||
)
|
||||
data_loader_len = len(data_loader) // batch_size
|
||||
data_loader_len = len(data_loader)
|
||||
actual_eff = sampler.efficiency()
|
||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
"""
|
||||
E2E tests for multipack fft llama using 4d attention masks
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import require_torch_2_1_1, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class Test4dMultipackLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using 4d attention with multipack
|
||||
"""
|
||||
|
||||
@require_torch_2_1_1
|
||||
@with_temp_dir
|
||||
def test_sdp_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"flash_attention": False,
|
||||
"sdp_attention": True,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_torch_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"flash_attention": False,
|
||||
"sdp_attention": False,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
@@ -33,7 +33,6 @@ class TestFusedLlama(unittest.TestCase):
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"flash_attention": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"flash_attn_fuse_qkv": True,
|
||||
"flash_attn_fuse_mlp": True,
|
||||
"sample_packing": True,
|
||||
|
||||
@@ -7,6 +7,8 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
@@ -61,7 +63,6 @@ class TestMistral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
@@ -102,9 +103,12 @@ class TestMistral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""
|
||||
E2E tests for relora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestReLoraLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_relora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_modules": ["q_proj", "v_proj"],
|
||||
"relora_steps": 25,
|
||||
"relora_warmup_steps": 5,
|
||||
"relora_anneal_steps": 5,
|
||||
"relora_cpu_offload": True,
|
||||
"val_set_size": 0.0,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"warmup_steps": 15,
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
@@ -4,9 +4,7 @@ helper utils for tests
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import wraps
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -33,15 +31,3 @@ def most_recent_subdir(path):
|
||||
subdir = max(subdirectories, key=os.path.getctime)
|
||||
|
||||
return subdir
|
||||
|
||||
|
||||
def require_torch_2_1_1(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.1.1
|
||||
"""
|
||||
|
||||
def is_min_2_1_1():
|
||||
torch_version = version("torch")
|
||||
return torch_version >= "2.1.1"
|
||||
|
||||
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
|
||||
|
||||
@@ -30,20 +30,6 @@ class TestMonkeyPatchUtils(unittest.TestCase):
|
||||
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
||||
)
|
||||
|
||||
def test_get_cu_seqlens_from_pos_ids_2d(self):
|
||||
position_ids = torch.tensor(
|
||||
[
|
||||
[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0],
|
||||
[0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0],
|
||||
]
|
||||
)
|
||||
target_res = torch.tensor(
|
||||
[[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
||||
)
|
||||
|
||||
def test_get_max_seqlen_in_batch(self):
|
||||
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
||||
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
|
||||
|
||||
@@ -39,6 +39,32 @@ class TestExpandMask(unittest.TestCase):
|
||||
# Check that the output matches the expected output
|
||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||
|
||||
def test_output_multipack(self):
|
||||
mask = torch.tensor([[1, 1, 1, 0], [2, 2, 3, 3]])
|
||||
dtype = torch.float32
|
||||
expected_output = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, 0.0000e00, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, 0.0000e00, 0.0000e00],
|
||||
]
|
||||
],
|
||||
]
|
||||
)
|
||||
# Check that the output matches the expected output
|
||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
"""Module for testing streaming dataset sequence packing"""
|
||||
import pytest
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.completion import load
|
||||
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
|
||||
def fixture_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
tokenizer.pad_token = "</s>"
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="max_seq_length")
|
||||
def fixture_max_seq_length():
|
||||
return 4096
|
||||
|
||||
|
||||
class TestBatchedSamplerPacking:
|
||||
"""
|
||||
Test class for packing streaming dataset sequences
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size, num_workers",
|
||||
[
|
||||
(1, 0),
|
||||
(2, 0),
|
||||
(1, 2),
|
||||
(2, 2),
|
||||
],
|
||||
)
|
||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
dataset = load_dataset(
|
||||
"Trelis/tiny-shakespeare",
|
||||
split="train",
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"train_on_inputs": True,
|
||||
"sequence_len": max_seq_length,
|
||||
}
|
||||
)
|
||||
ds_cfg = DictDefault(
|
||||
{
|
||||
"field": "Text",
|
||||
}
|
||||
)
|
||||
completion_strategy = load(tokenizer, cfg, ds_cfg)
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
completion_strategy,
|
||||
dataset,
|
||||
)
|
||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||
batch_sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=max_seq_length,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
)
|
||||
|
||||
loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( # pylint: disable=unexpected-keyword-arg
|
||||
tokenizer=tokenizer,
|
||||
padding=True,
|
||||
pad_to_multiple_of=max_seq_length,
|
||||
return_tensors="pt",
|
||||
),
|
||||
num_workers=num_workers,
|
||||
)
|
||||
inputs = next(iter(loader))
|
||||
|
||||
assert inputs["input_ids"].shape == (batch_size, max_seq_length)
|
||||
assert inputs["labels"].shape == (batch_size, max_seq_length)
|
||||
assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
|
||||
|
||||
assert inputs["input_ids"].tolist()[0][0] == 2
|
||||
assert inputs["labels"].tolist()[0][0] == -100
|
||||
assert inputs["attention_mask"].tolist()[0][0] == 0
|
||||
assert inputs["attention_mask"].tolist()[0][-1] > 1
|
||||
|
||||
if batch_size >= 2:
|
||||
assert inputs["input_ids"].tolist()[1][0] == 2
|
||||
assert inputs["labels"].tolist()[1][0] == -100
|
||||
assert inputs["attention_mask"].tolist()[1][0] == 0
|
||||
assert inputs["attention_mask"].tolist()[1][-1] > 1
|
||||
@@ -1,17 +1,17 @@
|
||||
"""Module for testing streaming dataset sequence packing"""
|
||||
import functools
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.data import encode_packed_pretraining
|
||||
|
||||
|
||||
class TestPretrainingPacking(unittest.TestCase):
|
||||
class TestPacking(unittest.TestCase):
|
||||
"""
|
||||
Test class for packing streaming dataset sequences
|
||||
"""
|
||||
@@ -20,6 +20,8 @@ class TestPretrainingPacking(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.pad_token = "</s>"
|
||||
self.max_seq_length = 2048
|
||||
self.batch_size = 2
|
||||
|
||||
def test_packing_stream_dataset(self):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -29,43 +31,30 @@ class TestPretrainingPacking(unittest.TestCase):
|
||||
streaming=True,
|
||||
)["train"]
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"pretraining_dataset": [
|
||||
{
|
||||
"path": "c4",
|
||||
"name": "en",
|
||||
"type": "pretrain",
|
||||
}
|
||||
],
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"micro_batch_size": 2,
|
||||
}
|
||||
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
pad_to_multiple_of=self.max_seq_length,
|
||||
)
|
||||
|
||||
ds_wrapper_partial = functools.partial(
|
||||
get_dataset_wrapper,
|
||||
cfg.pretraining_dataset[0],
|
||||
encode = partial(
|
||||
encode_packed_pretraining,
|
||||
self.tokenizer,
|
||||
cfg,
|
||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||
collate_fn,
|
||||
max_seq_length=self.max_seq_length,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
|
||||
original_bsz = cfg.micro_batch_size
|
||||
train_dataset = wrap_pretraining_dataset(
|
||||
dataset,
|
||||
self.tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_partial,
|
||||
max_tokens=cfg.sequence_len,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seed=cfg.seed or 42,
|
||||
dataset = dataset.map(
|
||||
encode,
|
||||
batched=True,
|
||||
input_columns="text",
|
||||
remove_columns=dataset.features.keys(),
|
||||
)
|
||||
|
||||
trainer_loader = DataLoader(
|
||||
train_dataset,
|
||||
dataset,
|
||||
batch_size=1,
|
||||
collate_fn=None,
|
||||
drop_last=True,
|
||||
@@ -75,16 +64,16 @@ class TestPretrainingPacking(unittest.TestCase):
|
||||
if idx > 10:
|
||||
break
|
||||
assert data["input_ids"].shape == torch.Size(
|
||||
[1, original_bsz * cfg.sequence_len]
|
||||
[1, self.batch_size * self.max_seq_length]
|
||||
)
|
||||
assert data["position_ids"].shape == torch.Size(
|
||||
[1, original_bsz * cfg.sequence_len]
|
||||
[1, self.batch_size * self.max_seq_length]
|
||||
)
|
||||
assert data["labels"].shape == torch.Size(
|
||||
[1, original_bsz * cfg.sequence_len]
|
||||
[1, self.batch_size * self.max_seq_length]
|
||||
)
|
||||
assert data["attention_mask"].shape == torch.Size(
|
||||
[1, original_bsz * cfg.sequence_len]
|
||||
[1, self.batch_size * self.max_seq_length]
|
||||
)
|
||||
idx += 1
|
||||
|
||||
|
||||
@@ -67,21 +67,6 @@ class TestTokenizers(unittest.TestCase):
|
||||
)
|
||||
load_tokenizer(cfg)
|
||||
|
||||
def test_add_additional_special_tokens(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"special_tokens": {"additional_special_tokens": ["<|im_start|>"]},
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
|
||||
self.assertEqual(len(tokenizer), 32001)
|
||||
|
||||
# ensure reloading the tokenizer again from cfg results in same vocab length
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
self.assertEqual(len(tokenizer), 32001)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -26,12 +26,21 @@ class BaseValidation(unittest.TestCase):
|
||||
self._caplog = caplog
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class ValidationTest(BaseValidation):
|
||||
"""
|
||||
Test the validation module
|
||||
"""
|
||||
|
||||
def test_load_4bit_deprecate(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_4bit": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_batch_size_unused_warning(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -689,22 +698,6 @@ class ValidationTest(BaseValidation):
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_hub_model_id_save_value_warns(self):
|
||||
cfg = DictDefault({"hub_model_id": "test"})
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert (
|
||||
"set without any models being saved" in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_hub_model_id_save_value(self):
|
||||
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
|
||||
class ValidationCheckModelConfig(BaseValidation):
|
||||
"""
|
||||
|
||||
98
ui/main.py
98
ui/main.py
@@ -1,98 +0,0 @@
|
||||
"""
|
||||
This module is used to launch Axolotl with user defined configurations.
|
||||
"""
|
||||
|
||||
import gradio as gr
|
||||
import yaml
|
||||
|
||||
|
||||
def config(
|
||||
base_model,
|
||||
dataset,
|
||||
dataset_type,
|
||||
learn_rate,
|
||||
gradient_accumulation_steps,
|
||||
micro_batch_size,
|
||||
seq_length,
|
||||
num_epochs,
|
||||
output_dir,
|
||||
val_size,
|
||||
):
|
||||
"""
|
||||
This function generates a configuration dictionary and saves it as a yaml file.
|
||||
"""
|
||||
config_dict = {
|
||||
"base_model": base_model,
|
||||
"datasets": [{"path": dataset, "type": dataset_type}],
|
||||
"learning_rate": learn_rate,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"micro_batch_size": micro_batch_size,
|
||||
"sequence_len": seq_length,
|
||||
"num_epochs": num_epochs,
|
||||
"output_dir": output_dir,
|
||||
"val_set_size": val_size,
|
||||
}
|
||||
with open("config.yml", "w", encoding="utf-8") as file:
|
||||
yaml.dump(config_dict, file)
|
||||
print(config_dict)
|
||||
return yaml.dump(config_dict)
|
||||
|
||||
|
||||
with gr.Blocks(title="Axolotl Launcher") as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# Axolotl Launcher
|
||||
Fill out the required fields below to create a training run.
|
||||
"""
|
||||
)
|
||||
with gr.Row():
|
||||
base_model_name = gr.Textbox(
|
||||
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", label="Base model"
|
||||
)
|
||||
|
||||
mode = gr.Radio(
|
||||
choices=["Full finetune", "QLoRA", "LoRA"],
|
||||
label="Training mode",
|
||||
info="FFT = 16 bit, Qlora = 4 bit, Lora = 8 bit",
|
||||
)
|
||||
with gr.Row():
|
||||
dataset_path = gr.Textbox("mhenrichsen/alpaca_2k_test", label="Dataset")
|
||||
dataset_type_name = gr.Dropdown(
|
||||
choices=["alpaca", "sharegpt"], label="Dataset type", value="alpaca"
|
||||
)
|
||||
with gr.Accordion("Hyperparameters", open=False):
|
||||
gr.Markdown("Choose hyperparameters")
|
||||
with gr.Row():
|
||||
learning_rate = gr.Number(0.000001, label="Learning rate")
|
||||
gradient_accumulation_steps_count = gr.Number(
|
||||
1, label="Gradient accumulation steps"
|
||||
)
|
||||
val_set_size_count = gr.Number(0, label="Validation size")
|
||||
|
||||
with gr.Row():
|
||||
micro_batch_size_count = gr.Number(1, label="Micro batch size")
|
||||
sequence_length = gr.Number(1024, label="Sequence length")
|
||||
num_epochs_count = gr.Number(1, label="Epochs")
|
||||
|
||||
output_dir_path = gr.Textbox("./model-out", label="Output directory")
|
||||
|
||||
create_config = gr.Button("Create config")
|
||||
output = gr.TextArea(label="Generated config")
|
||||
create_config.click(
|
||||
config,
|
||||
inputs=[
|
||||
base_model_name,
|
||||
dataset_path,
|
||||
dataset_type_name,
|
||||
learning_rate,
|
||||
gradient_accumulation_steps_count,
|
||||
micro_batch_size_count,
|
||||
sequence_length,
|
||||
num_epochs_count,
|
||||
output_dir_path,
|
||||
val_set_size_count,
|
||||
],
|
||||
outputs=output,
|
||||
)
|
||||
|
||||
demo.launch(debug=True, server_name="0.0.0.0", server_port=7860)
|
||||
Reference in New Issue
Block a user