Compare commits
11 Commits
mixtral_op
...
hamelsmu-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
856f5f6115 | ||
|
|
d25c34caa6 | ||
|
|
13e938149d | ||
|
|
85de004dd4 | ||
|
|
80ec7af358 | ||
|
|
f28e75513b | ||
|
|
5ada140ff0 | ||
|
|
712fd27b3f | ||
|
|
ef24342538 | ||
|
|
5ea3aa31f0 | ||
|
|
f1f60cb5b2 |
43
README.md
43
README.md
@@ -36,7 +36,9 @@ Features:
|
|||||||
- [Train](#train)
|
- [Train](#train)
|
||||||
- [Inference](#inference)
|
- [Inference](#inference)
|
||||||
- [Merge LORA to Base](#merge-lora-to-base)
|
- [Merge LORA to Base](#merge-lora-to-base)
|
||||||
|
- [Special Tokens](#special-tokens)
|
||||||
- [Common Errors](#common-errors-)
|
- [Common Errors](#common-errors-)
|
||||||
|
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
||||||
- [Need Help?](#need-help-)
|
- [Need Help?](#need-help-)
|
||||||
- [Badge](#badge-)
|
- [Badge](#badge-)
|
||||||
- [Community Showcase](#community-showcase)
|
- [Community Showcase](#community-showcase)
|
||||||
@@ -100,7 +102,7 @@ pip3 install -e '.[flash-attn,deepspeed]'
|
|||||||
```
|
```
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
```bash
|
```bashtet
|
||||||
# finetune lora
|
# finetune lora
|
||||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||||
|
|
||||||
@@ -251,6 +253,13 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
|
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
|
||||||
|
```yml
|
||||||
|
datasets:
|
||||||
|
- path: <your-path>
|
||||||
|
type: sharegpt
|
||||||
|
conversation: llama-2
|
||||||
|
```
|
||||||
- `completion`: raw corpus
|
- `completion`: raw corpus
|
||||||
```json
|
```json
|
||||||
{"text": "..."}
|
{"text": "..."}
|
||||||
@@ -774,7 +783,7 @@ max_grad_norm:
|
|||||||
# Augmentation techniques
|
# Augmentation techniques
|
||||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||||
# currently only supported on Llama and Mistral
|
# currently only supported on Llama and Mistral
|
||||||
noisy_embedding_alpha:
|
neftune_noise_alpha:
|
||||||
|
|
||||||
# Whether to bettertransformers
|
# Whether to bettertransformers
|
||||||
flash_optimum:
|
flash_optimum:
|
||||||
@@ -970,6 +979,22 @@ wandb_name:
|
|||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
##### Special Tokens
|
||||||
|
|
||||||
|
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
|
||||||
|
|
||||||
|
```yml
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
|
tokens: # these are delimiters
|
||||||
|
- "<|im_start|>"
|
||||||
|
- "<|im_end|>"
|
||||||
|
```
|
||||||
|
|
||||||
|
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
|
||||||
|
|
||||||
### Inference
|
### Inference
|
||||||
|
|
||||||
Pass the appropriate flag to the train command:
|
Pass the appropriate flag to the train command:
|
||||||
@@ -1048,6 +1073,20 @@ It's safe to ignore it.
|
|||||||
|
|
||||||
See the [NCCL](docs/nccl.md) guide.
|
See the [NCCL](docs/nccl.md) guide.
|
||||||
|
|
||||||
|
|
||||||
|
### Tokenization Mismatch b/w Inference & Training
|
||||||
|
|
||||||
|
For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.
|
||||||
|
|
||||||
|
If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following:
|
||||||
|
|
||||||
|
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
|
||||||
|
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
|
||||||
|
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly.
|
||||||
|
4. As an additional troubleshooting step, you can look look at the token ids between 1 and 2 to make sure they are identical.
|
||||||
|
|
||||||
|
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
|
||||||
|
|
||||||
## Need help? 🙋♂️
|
## Need help? 🙋♂️
|
||||||
|
|
||||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||||
|
|||||||
39
deepspeed/zero3_bf16.json
Normal file
39
deepspeed/zero3_bf16.json
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
{
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 3,
|
||||||
|
"overlap_comm": true,
|
||||||
|
"contiguous_gradients": true,
|
||||||
|
"sub_group_size": 0,
|
||||||
|
"reduce_bucket_size": "auto",
|
||||||
|
"stage3_prefetch_bucket_size": "auto",
|
||||||
|
"stage3_param_persistence_threshold": "auto",
|
||||||
|
"stage3_max_live_parameters": 0,
|
||||||
|
"stage3_max_reuse_distance": 0,
|
||||||
|
"stage3_gather_16bit_weights_on_model_save": true
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": true
|
||||||
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 32,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": "auto",
|
||||||
|
"betas": "auto",
|
||||||
|
"eps": "auto",
|
||||||
|
"weight_decay": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"wall_clock_breakdown": false
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ ARG PYTORCH_VERSION="2.0.1"
|
|||||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y vim curl
|
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,15 @@ dataset_prepared_path: last_run_prepared
|
|||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||||
|
unfrozen_parameters:
|
||||||
|
# - lm_head.*
|
||||||
|
# - model.embed_tokens.*
|
||||||
|
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
|
||||||
|
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
|
||||||
|
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
||||||
|
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
auto-gptq==0.5.1
|
auto-gptq==0.5.1
|
||||||
packaging
|
packaging
|
||||||
peft==0.6.0
|
peft==0.6.0
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@e5079b0b2abcef11ecbdae60ba4a6636c57b725d
|
transformers @ git+https://github.com/huggingface/transformers.git@ebfdb9ca62205279d5019ef1403877461b3b2da4
|
||||||
tokenizers==0.15.0
|
tokenizers==0.15.0
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate==0.24.1
|
accelerate==0.24.1
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ LOG = logging.getLogger("axolotl.cli.train")
|
|||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
|
|||||||
@@ -692,6 +692,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||||
else "cosine"
|
else "cosine"
|
||||||
)
|
)
|
||||||
|
training_arguments_kwargs["lr_scheduler_kwargs"] = (
|
||||||
|
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||||
|
)
|
||||||
training_arguments_kwargs["weight_decay"] = (
|
training_arguments_kwargs["weight_decay"] = (
|
||||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||||
)
|
)
|
||||||
@@ -712,6 +715,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs
|
training_arguments_kwargs
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||||
|
|
||||||
|
if self.cfg.neftune_noise_alpha is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"neftune_noise_alpha"
|
||||||
|
] = self.cfg.neftune_noise_alpha
|
||||||
|
|
||||||
training_args = (
|
training_args = (
|
||||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
**training_arguments_kwargs,
|
**training_arguments_kwargs,
|
||||||
|
|||||||
@@ -83,14 +83,21 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
yield role + ":", ""
|
yield role + ":", ""
|
||||||
return
|
return
|
||||||
if self.sep_style == SeparatorStyle.LLAMA2:
|
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||||
seps = [self.sep, self.sep2]
|
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
|
if self.messages:
|
||||||
|
# For llama, the system message is incorporated into the first human instruction
|
||||||
|
first_role, first_msg = self.messages[0]
|
||||||
|
if first_role == self.roles[0]:
|
||||||
|
system_prompt += first_msg
|
||||||
|
self.messages.pop(0)
|
||||||
yield "", system_prompt
|
yield "", system_prompt
|
||||||
else:
|
for i, (role, message) in enumerate(self.messages):
|
||||||
yield "", "[INST] "
|
|
||||||
for i, (role, message) in enumerate(self.messages[1:]):
|
|
||||||
if message:
|
if message:
|
||||||
yield role + " ", message + seps[i % 2]
|
if (i % 2 == 0 and not self.system_message) or (
|
||||||
|
i % 2 != 0 and self.system_message
|
||||||
|
):
|
||||||
|
role = "<s> " + role
|
||||||
|
yield role + " ", message
|
||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,65 +0,0 @@
|
|||||||
"""
|
|
||||||
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
from peft import PeftModel
|
|
||||||
from transformers import PreTrainedModel
|
|
||||||
|
|
||||||
|
|
||||||
def patch_neft(alpha, model):
|
|
||||||
embeddings = None
|
|
||||||
if isinstance(model, PreTrainedModel):
|
|
||||||
embeddings = model.get_input_embeddings()
|
|
||||||
if isinstance(model, PeftModel):
|
|
||||||
embeddings = model.base_model.get_input_embeddings()
|
|
||||||
if not embeddings:
|
|
||||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
|
||||||
embeddings.noisy_embedding_alpha = alpha
|
|
||||||
old_forward = embeddings.forward
|
|
||||||
|
|
||||||
# This hack seems to be needed to properly use a custom forward pass
|
|
||||||
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
|
||||||
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
|
|
||||||
embeddings, embeddings.__class__
|
|
||||||
)
|
|
||||||
setattr(embeddings, "forward", bound_method)
|
|
||||||
|
|
||||||
embeddings._old_forward = old_forward # pylint: disable=protected-access
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def unpatch_neft(model):
|
|
||||||
embeddings = None
|
|
||||||
if isinstance(model, PreTrainedModel):
|
|
||||||
embeddings = model.get_input_embeddings()
|
|
||||||
if isinstance(model, PeftModel):
|
|
||||||
embeddings = model.base_model.get_input_embeddings()
|
|
||||||
if not embeddings:
|
|
||||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
|
||||||
if hasattr(embeddings, "_old_forward"):
|
|
||||||
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
|
|
||||||
del embeddings._old_forward # pylint: disable=protected-access
|
|
||||||
del embeddings.noisy_embedding_alpha
|
|
||||||
|
|
||||||
|
|
||||||
def neft_forward(self, inputs: torch.Tensor):
|
|
||||||
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
|
|
||||||
|
|
||||||
if self.training:
|
|
||||||
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
|
|
||||||
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
|
|
||||||
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
|
|
||||||
-mag_norm, mag_norm
|
|
||||||
)
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def pretrain_hook(cfg, trainer):
|
|
||||||
if cfg.noisy_embedding_alpha:
|
|
||||||
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
|
|
||||||
|
|
||||||
|
|
||||||
def post_train_hook(cfg, trainer):
|
|
||||||
if cfg.noisy_embedding_alpha:
|
|
||||||
unpatch_neft(trainer.model)
|
|
||||||
@@ -16,8 +16,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
|
|||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.monkeypatch import neft_embeddings
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.freeze import freeze_parameters_except
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
@@ -78,6 +78,9 @@ def train(
|
|||||||
)
|
)
|
||||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||||
|
|
||||||
|
if cfg.unfrozen_parameters:
|
||||||
|
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
||||||
|
|
||||||
trainer = setup_trainer(
|
trainer = setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||||
)
|
)
|
||||||
@@ -176,21 +179,19 @@ def train(
|
|||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
def pretrain_hooks(cfg, trainer):
|
def pretrain_hooks(_cfg, _trainer):
|
||||||
"""
|
"""
|
||||||
Run hooks right before kicking off the training
|
Run hooks right before kicking off the training
|
||||||
:param cfg:
|
:param cfg:
|
||||||
:param trainer:
|
:param trainer:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
neft_embeddings.pretrain_hook(cfg, trainer)
|
|
||||||
|
|
||||||
|
|
||||||
def post_train_hooks(cfg, trainer):
|
def post_train_hooks(_cfg, _trainer):
|
||||||
"""
|
"""
|
||||||
Run hooks right after training completes
|
Run hooks right after training completes
|
||||||
:param cfg:
|
:param cfg:
|
||||||
:param trainer:
|
:param trainer:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
neft_embeddings.post_train_hook(cfg, trainer)
|
|
||||||
|
|||||||
@@ -434,6 +434,20 @@ def validate_config(cfg):
|
|||||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.noisy_embedding_alpha is not None:
|
||||||
|
# Deprecated, use neftune_noise_alpha
|
||||||
|
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||||
|
if cfg.neftune_noise_alpha is None:
|
||||||
|
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
||||||
|
else:
|
||||||
|
# User is providing both; bail and have them sort out their settings
|
||||||
|
raise ValueError(
|
||||||
|
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||||
|
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
38
src/axolotl/utils/freeze.py
Normal file
38
src/axolotl/utils/freeze.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""
|
||||||
|
module to freeze/unfreeze parameters by name
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.utils.freeze")
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_parameters_except(model, regex_patterns):
|
||||||
|
"""
|
||||||
|
Freezes all layers of the given model except for the layers that match given regex patterns.
|
||||||
|
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- model (nn.Module): The PyTorch model to be modified.
|
||||||
|
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None; the model is modified in place.
|
||||||
|
"""
|
||||||
|
# Escape periods and compile the regex patterns
|
||||||
|
compiled_patterns = [
|
||||||
|
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
|
||||||
|
]
|
||||||
|
|
||||||
|
# First, freeze all parameters in the model
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Unfreeze layers that match the regex patterns
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if any(pattern.match(name) for pattern in compiled_patterns):
|
||||||
|
if is_main_process():
|
||||||
|
LOG.debug(f"unfreezing {name}")
|
||||||
|
param.requires_grad = True
|
||||||
@@ -21,6 +21,7 @@ from transformers import ( # noqa: F401
|
|||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
|
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
@@ -285,6 +286,9 @@ def load_model(
|
|||||||
model_kwargs["max_memory"] = cfg.max_memory
|
model_kwargs["max_memory"] = cfg.max_memory
|
||||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||||
|
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
if cfg.model_revision:
|
if cfg.model_revision:
|
||||||
model_kwargs["revision"] = cfg.model_revision
|
model_kwargs["revision"] = cfg.model_revision
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
@@ -324,6 +328,10 @@ def load_model(
|
|||||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"eager"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||||
|
|||||||
@@ -276,6 +276,7 @@ def prepare_optim_env(cfg):
|
|||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
|
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
|
|||||||
@@ -114,6 +114,76 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
in self._caplog.records[0].message
|
in self._caplog.records[0].message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_sharegpt_llama(self):
|
||||||
|
"Make sure the sharegpt/llama is tokenized and formatted correctly."
|
||||||
|
prompter = ShareGPTPrompterV2(conversation="llama-2")
|
||||||
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
|
prompter,
|
||||||
|
self.tokenizer,
|
||||||
|
False,
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
|
||||||
|
def tokenize(conv):
|
||||||
|
return strat.tokenize_prompt(conv)["input_ids"]
|
||||||
|
|
||||||
|
def decode(ids):
|
||||||
|
return strat.tokenizer.decode(ids)
|
||||||
|
|
||||||
|
# Multi-turn conversations
|
||||||
|
multi_turn_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "123"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
# fmt: off
|
||||||
|
mt_ids = tokenize(multi_turn_conv)
|
||||||
|
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||||
|
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
|
||||||
|
# Single-turn conversations
|
||||||
|
single_turn_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
st_ids = tokenize(single_turn_conv)
|
||||||
|
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
|
||||||
|
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
|
# No system message, single-turn
|
||||||
|
no_sys_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
ns_ids = tokenize(no_sys_conv)
|
||||||
|
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
||||||
|
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
|
# No system message, multi-turn
|
||||||
|
no_sys_mt_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "123"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
ns_mt_ids = tokenize(no_sys_mt_conv)
|
||||||
|
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||||
|
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
def test_sharegpt_changes_roles(self):
|
def test_sharegpt_changes_roles(self):
|
||||||
conversation = {
|
conversation = {
|
||||||
"roles": ["USER", "CHARACTER"],
|
"roles": ["USER", "CHARACTER"],
|
||||||
|
|||||||
Reference in New Issue
Block a user