Compare commits
1 Commits
fix-ddp_fi
...
20240307-u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b432346e3 |
@@ -25,7 +25,7 @@ Features:
|
|||||||
- [Environment](#environment)
|
- [Environment](#environment)
|
||||||
- [Docker](#docker)
|
- [Docker](#docker)
|
||||||
- [Conda/Pip venv](#condapip-venv)
|
- [Conda/Pip venv](#condapip-venv)
|
||||||
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
|
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
|
||||||
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
|
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
|
||||||
- [Windows](#windows)
|
- [Windows](#windows)
|
||||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||||
@@ -199,7 +199,6 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
|||||||
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
||||||
|
|
||||||
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||||
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
|
|
||||||
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||||
|
|
||||||
#### Bare Metal Cloud GPU
|
#### Bare Metal Cloud GPU
|
||||||
@@ -1080,10 +1079,6 @@ fsdp_config:
|
|||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
```
|
```
|
||||||
|
|
||||||
##### FSDP + QLoRA
|
|
||||||
|
|
||||||
Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.md) for more information.
|
|
||||||
|
|
||||||
##### Weights & Biases Logging
|
##### Weights & Biases Logging
|
||||||
|
|
||||||
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
||||||
@@ -1303,6 +1298,4 @@ consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsor
|
|||||||
|
|
||||||
#### 🥉 Bronze Sponsors - $500/mo
|
#### 🥉 Bronze Sponsors - $500/mo
|
||||||
|
|
||||||
- [JarvisLabs.ai](https://jarvislabs.ai)
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
"min_loss_scale": 1
|
"min_loss_scale": 1
|
||||||
},
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"gradient_clipping": "auto",
|
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
"wall_clock_breakdown": false
|
"wall_clock_breakdown": false
|
||||||
|
|||||||
@@ -20,7 +20,6 @@
|
|||||||
"min_loss_scale": 1
|
"min_loss_scale": 1
|
||||||
},
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"gradient_clipping": "auto",
|
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
"wall_clock_breakdown": false
|
"wall_clock_breakdown": false
|
||||||
|
|||||||
@@ -24,7 +24,6 @@
|
|||||||
"min_loss_scale": 1
|
"min_loss_scale": 1
|
||||||
},
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"gradient_clipping": "auto",
|
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
"wall_clock_breakdown": false
|
"wall_clock_breakdown": false
|
||||||
|
|||||||
@@ -24,7 +24,6 @@
|
|||||||
"min_loss_scale": 1
|
"min_loss_scale": 1
|
||||||
},
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"gradient_clipping": "auto",
|
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
"wall_clock_breakdown": false
|
"wall_clock_breakdown": false
|
||||||
|
|||||||
@@ -1,37 +0,0 @@
|
|||||||
# FDSP + QLoRA
|
|
||||||
|
|
||||||
## Background
|
|
||||||
|
|
||||||
Using FSDP with QLoRA is essential for **fine-tuning larger (70b+ parameter) LLMs on consumer GPUs.** For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs[^1].
|
|
||||||
|
|
||||||
Below, we describe how to use this feature in Axolotl.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
|
||||||
|
|
||||||
> ![Tip]
|
|
||||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
|
||||||
|
|
||||||
1. Set `adapter: qlora` in your axolotl config file.
|
|
||||||
2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
|
|
||||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
|
||||||
|
|
||||||
## Example Config
|
|
||||||
|
|
||||||
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
|
|
||||||
|
|
||||||
## References
|
|
||||||
|
|
||||||
- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
|
|
||||||
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
|
|
||||||
- Related HuggingFace PRs Enabling FDSP + QLoRA:
|
|
||||||
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )
|
|
||||||
- Transformers [PR#29587](https://github.com/huggingface/transformers/pull/29587)
|
|
||||||
- TRL [PR#1416](https://github.com/huggingface/trl/pull/1416)
|
|
||||||
- PEFT [PR#1550](https://github.com/huggingface/peft/pull/1550)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
[^1]: This was enabled by [this work](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the Answer.AI team.
|
|
||||||
@@ -21,7 +21,7 @@ lora_dropout: 0.05
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: false
|
sample_packing: true
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
|
|||||||
@@ -1,70 +0,0 @@
|
|||||||
base_model: NousResearch/Llama-2-7b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
|
||||||
tokenizer_type: LlamaTokenizer
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: yahma/alpaca-cleaned
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.05
|
|
||||||
output_dir: ./qlora-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 512
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules:
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 4
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: paged_adamw_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.00001
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 4
|
|
||||||
eval_table_size:
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
- full_shard
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
base_model: mistralai/Mixtral-8x7B-v0.1
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: LlamaTokenizer
|
|
||||||
trust_remote_code: true
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: tatsu-lab/alpaca
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.02
|
|
||||||
output_dir: ./qlora-out
|
|
||||||
|
|
||||||
model_config:
|
|
||||||
output_router_logits: true
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 1024
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: paged_adamw_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
|
||||||
loss_watchdog_patience: 3
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 4
|
|
||||||
eval_table_size:
|
|
||||||
eval_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
- full_shard
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
|
|
||||||
special_tokens:
|
|
||||||
@@ -16,12 +16,12 @@ output_dir: ./qlora-out
|
|||||||
|
|
||||||
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||||
unfrozen_parameters:
|
unfrozen_parameters:
|
||||||
# - ^lm_head.weight$
|
# - lm_head.*
|
||||||
# - ^model.embed_tokens.weight$[:32000]
|
# - model.embed_tokens.*
|
||||||
# - model.layers.2[0-9]+.block_sparse_moe.gate
|
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
|
||||||
# - model.layers.2[0-9]+.block_sparse_moe.experts
|
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
|
||||||
# - model.layers.3[0-9]+.block_sparse_moe.gate
|
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
||||||
# - model.layers.3[0-9]+.block_sparse_moe.experts
|
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
||||||
|
|
||||||
model_config:
|
model_config:
|
||||||
output_router_logits: true
|
output_router_logits: true
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ packaging==23.2
|
|||||||
peft==0.9.0
|
peft==0.9.0
|
||||||
transformers==4.38.2
|
transformers==4.38.2
|
||||||
tokenizers==0.15.0
|
tokenizers==0.15.0
|
||||||
bitsandbytes>=0.43.0
|
bitsandbytes>=0.41.1
|
||||||
accelerate==0.26.1
|
accelerate==0.26.1
|
||||||
deepspeed==0.13.1
|
deepspeed==0.13.1
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
@@ -40,4 +40,3 @@ gcsfs
|
|||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl>=0.7.9
|
trl>=0.7.9
|
||||||
fastcore>=1.5.29
|
|
||||||
|
|||||||
@@ -1,55 +0,0 @@
|
|||||||
"""module for building the auto wrap policy for FSDP"""
|
|
||||||
import functools
|
|
||||||
|
|
||||||
from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
|
|
||||||
from torch.distributed.fsdp.wrap import (
|
|
||||||
_or_policy,
|
|
||||||
lambda_auto_wrap_policy,
|
|
||||||
transformer_auto_wrap_policy,
|
|
||||||
)
|
|
||||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
|
||||||
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
|
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
|
|
||||||
|
|
||||||
SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
|
|
||||||
"llama",
|
|
||||||
"mistral",
|
|
||||||
"mixtral",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_wrapping_policy_factory(model_type):
|
|
||||||
if model_type == "llama":
|
|
||||||
layer_to_wrap = LlamaDecoderLayer
|
|
||||||
elif model_type == "mistral":
|
|
||||||
layer_to_wrap = MistralDecoderLayer
|
|
||||||
elif model_type == "mixtral":
|
|
||||||
layer_to_wrap = MixtralDecoderLayer
|
|
||||||
|
|
||||||
def get_wrapping_policy():
|
|
||||||
"""This checks for lora layers (has weight and requires_grad)"""
|
|
||||||
|
|
||||||
def lambda_policy_fn(module):
|
|
||||||
return (
|
|
||||||
len(list(module.named_children())) == 0
|
|
||||||
and getattr(module, "weight", None) is not None
|
|
||||||
and module.weight.requires_grad
|
|
||||||
)
|
|
||||||
|
|
||||||
lambda_policy = functools.partial(
|
|
||||||
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
|
|
||||||
)
|
|
||||||
transformer_layer_name = layer_to_wrap
|
|
||||||
transformer_wrap_policy = functools.partial(
|
|
||||||
transformer_auto_wrap_policy,
|
|
||||||
transformer_layer_cls=(
|
|
||||||
PrefixEncoder,
|
|
||||||
PromptEncoder,
|
|
||||||
PromptEmbedding,
|
|
||||||
transformer_layer_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
policies = [lambda_policy, transformer_wrap_policy]
|
|
||||||
return functools.partial(_or_policy, policies=policies)
|
|
||||||
|
|
||||||
return get_wrapping_policy
|
|
||||||
@@ -8,7 +8,6 @@ import importlib
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@@ -18,10 +17,7 @@ from typing import List, Optional, Type, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import FullyShardedDataParallelPlugin
|
|
||||||
from accelerate.utils import str_to_bool
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from torch.distributed.fsdp import MixedPrecision
|
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -34,7 +30,6 @@ from transformers.trainer_utils import seed_worker
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
from axolotl.loraplus import create_loraplus_optimizer
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
@@ -196,10 +191,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=1e-6,
|
default=1e-6,
|
||||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||||
)
|
)
|
||||||
qlora: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "whether this is a qlora training"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
class AxolotlTrainer(Trainer):
|
||||||
@@ -477,56 +468,6 @@ class AxolotlTrainer(Trainer):
|
|||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
@wraps(Trainer.create_accelerator_and_postprocess)
|
|
||||||
def create_accelerator_and_postprocess(self):
|
|
||||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
||||||
res = super().create_accelerator_and_postprocess()
|
|
||||||
|
|
||||||
if self.args.qlora is False:
|
|
||||||
return res
|
|
||||||
|
|
||||||
# the rest of this method override is specific to fsdp + qlora (for now)
|
|
||||||
sync_module_states = (
|
|
||||||
str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
|
|
||||||
)
|
|
||||||
|
|
||||||
mp_policy = None
|
|
||||||
amp = os.environ["ACCELERATE_MIXED_PRECISION"]
|
|
||||||
if amp == "fp16":
|
|
||||||
mp_policy = MixedPrecision(
|
|
||||||
param_dtype=torch.float32,
|
|
||||||
reduce_dtype=torch.float32,
|
|
||||||
buffer_dtype=torch.float32,
|
|
||||||
)
|
|
||||||
elif amp == "bf16":
|
|
||||||
mp_policy = MixedPrecision(
|
|
||||||
param_dtype=torch.float32,
|
|
||||||
reduce_dtype=torch.float32,
|
|
||||||
buffer_dtype=torch.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
# If somehow we figure out how we want to parameterize we want to autocast buffers...
|
|
||||||
# mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
|
|
||||||
# load_param_skip_names = ['inv_freq']
|
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
|
||||||
wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
|
|
||||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
|
||||||
auto_wrap_policy=wrapping_policy(),
|
|
||||||
cpu_offload=False,
|
|
||||||
use_orig_params=False,
|
|
||||||
limit_all_gathers=True,
|
|
||||||
param_init_fn=lambda module: module.to_empty(
|
|
||||||
device=torch.device("cuda"), recurse=False
|
|
||||||
)
|
|
||||||
if (rank != 0 and sync_module_states)
|
|
||||||
else None,
|
|
||||||
mixed_precision_policy=mp_policy,
|
|
||||||
)
|
|
||||||
self.accelerator.state.fsdp_plugin = fsdp_plugin
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -800,7 +741,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
warmup_steps = None
|
|
||||||
if self.cfg.warmup_steps is not None:
|
if self.cfg.warmup_steps is not None:
|
||||||
warmup_steps = self.cfg.warmup_steps
|
warmup_steps = self.cfg.warmup_steps
|
||||||
elif self.cfg.warmup_ratio is not None:
|
elif self.cfg.warmup_ratio is not None:
|
||||||
@@ -846,9 +786,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.fsdp_config:
|
if self.cfg.fsdp_config:
|
||||||
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
|
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
|
||||||
|
|
||||||
if self.cfg.adapter == "qlora":
|
|
||||||
training_arguments_kwargs["qlora"] = True
|
|
||||||
|
|
||||||
# deepspeed
|
# deepspeed
|
||||||
if self.cfg.deepspeed:
|
if self.cfg.deepspeed:
|
||||||
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||||
@@ -1000,14 +937,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
and self.cfg.eval_steps
|
and self.cfg.eval_steps
|
||||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||||
) or False
|
) or False
|
||||||
ddp_find_unused_parameters = (
|
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||||
self.cfg.ddp_find_unused_parameters
|
False if self.cfg.ddp else None
|
||||||
if self.cfg.ddp_find_unused_parameters is not None
|
|
||||||
else (False if self.cfg.ddp else None)
|
|
||||||
)
|
)
|
||||||
training_arguments_kwargs[
|
|
||||||
"ddp_find_unused_parameters"
|
|
||||||
] = ddp_find_unused_parameters
|
|
||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
report_to = None
|
report_to = None
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ class ColorfulFormatter(Formatter):
|
|||||||
|
|
||||||
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||||
"version": 1,
|
"version": 1,
|
||||||
"disable_existing_loggers": False,
|
|
||||||
"formatters": {
|
"formatters": {
|
||||||
"simple": {
|
"simple": {
|
||||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
|
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
"""multipack patching for v2 of sample packing"""
|
"""multipack patching for v2 of sample packing"""
|
||||||
import importlib
|
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import init_empty_weights
|
|
||||||
from transformers import AutoConfig, AutoModelForCausalLM
|
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||||
@@ -15,12 +12,11 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"falcon",
|
"falcon",
|
||||||
"phi",
|
"phi",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemmoe",
|
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def patch_for_multipack(model_type, model_name=None):
|
def patch_for_multipack(model_type):
|
||||||
if model_type == "mixtral":
|
if model_type == "mixtral":
|
||||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
@@ -47,15 +43,3 @@ def patch_for_multipack(model_type, model_name=None):
|
|||||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
elif model_type == "gemmoe":
|
|
||||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
|
||||||
# we need to load the model here in order for modeling_gemmoe to be available
|
|
||||||
with init_empty_weights():
|
|
||||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
|
||||||
module_name = model_config.__class__.__module__.replace(
|
|
||||||
".configuration_gemmoe", ".modeling_gemmoe"
|
|
||||||
)
|
|
||||||
modeling_gemmoe = importlib.import_module(module_name)
|
|
||||||
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -24,25 +24,6 @@ def argilla(
|
|||||||
return transform_fn
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
def argilla_chat(
|
|
||||||
cfg,
|
|
||||||
**kwargs,
|
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
"""
|
|
||||||
for argilla/dpo-mix-7k conversations
|
|
||||||
"""
|
|
||||||
|
|
||||||
def transform_fn(sample):
|
|
||||||
sample[
|
|
||||||
"prompt"
|
|
||||||
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
|
||||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
|
|
||||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
|
|
||||||
return sample
|
|
||||||
|
|
||||||
return transform_fn
|
|
||||||
|
|
||||||
|
|
||||||
def icr(
|
def icr(
|
||||||
cfg,
|
cfg,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|||||||
@@ -1,15 +1,10 @@
|
|||||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||||
from axolotl.prompters import ShareGPTPrompterV2
|
from axolotl.prompters import ShareGPTPrompterV2
|
||||||
from axolotl.utils.tokenization import (
|
|
||||||
chatml_to_conversation,
|
|
||||||
merge_consecutive_messages,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def register_chatml_template(system_message=None):
|
def register_chatml_template(system_message=None):
|
||||||
@@ -24,16 +19,6 @@ def register_chatml_template(system_message=None):
|
|||||||
sep="<|im_end|>",
|
sep="<|im_end|>",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
register_conv_template(
|
|
||||||
Conversation(
|
|
||||||
name="chatml_glaive",
|
|
||||||
system_template="<|im_start|>system\n{system_message}",
|
|
||||||
system_message=system_message,
|
|
||||||
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
|
|
||||||
sep_style=SeparatorStyle.CHATML,
|
|
||||||
sep="<|im_end|>",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
@@ -92,20 +77,6 @@ def load_guanaco(tokenizer, cfg):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
||||||
conversation = (
|
|
||||||
ds_cfg["conversation"]
|
|
||||||
if ds_cfg and "conversation" in ds_cfg
|
|
||||||
else "chatml_glaive"
|
|
||||||
)
|
|
||||||
return GlaiveShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(conversation=conversation),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||||
"""
|
"""
|
||||||
basic sharegpt strategy to grab conversations from the sample row
|
basic sharegpt strategy to grab conversations from the sample row
|
||||||
@@ -187,15 +158,3 @@ class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingSt
|
|||||||
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
|
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
|
||||||
]
|
]
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
|
||||||
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
|
||||||
"""
|
|
||||||
sharegpt strategy that remaps glaive data to sharegpt format
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
|
||||||
conversation = chatml_to_conversation(prompt)
|
|
||||||
conversation = merge_consecutive_messages(conversation)
|
|
||||||
|
|
||||||
return conversation
|
|
||||||
|
|||||||
@@ -360,19 +360,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
LOG.warning(f"expected tuple, got {part}")
|
LOG.warning(f"expected tuple, got {part}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tool_role_label = None
|
user, assistant = conversation.roles
|
||||||
if len(conversation.roles) == 3:
|
|
||||||
(
|
|
||||||
user_role_label,
|
|
||||||
assistant_role_label,
|
|
||||||
tool_role_label,
|
|
||||||
) = conversation.roles
|
|
||||||
else:
|
|
||||||
user_role_label, assistant_role_label = conversation.roles
|
|
||||||
role, content = part
|
role, content = part
|
||||||
|
|
||||||
# Uses "in" because role contains extra characters
|
# Uses "in" because role contains extra characters
|
||||||
if user_role_label in role:
|
if user in role:
|
||||||
role = (
|
role = (
|
||||||
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||||
if role_remap
|
if role_remap
|
||||||
@@ -392,7 +384,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
else:
|
else:
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
elif assistant_role_label in role:
|
elif assistant in role:
|
||||||
role = (
|
role = (
|
||||||
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||||
if role_remap
|
if role_remap
|
||||||
@@ -434,8 +426,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
else:
|
else:
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
elif tool_role_label and tool_role_label in role:
|
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
|
||||||
else:
|
else:
|
||||||
LOG.warning(f"unhandled role: {role}")
|
LOG.warning(f"unhandled role: {role}")
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -267,8 +267,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
|
|
||||||
role_key_human = "human"
|
role_key_human = "human"
|
||||||
role_key_model = "gpt"
|
role_key_model = "gpt"
|
||||||
# Optional, only used for tool usage datasets.
|
|
||||||
role_key_tool = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -276,7 +274,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
conversation: Optional[Union[str, Conversation]] = None,
|
conversation: Optional[Union[str, Conversation]] = None,
|
||||||
role_key_human: Optional[str] = None,
|
role_key_human: Optional[str] = None,
|
||||||
role_key_model: Optional[str] = None,
|
role_key_model: Optional[str] = None,
|
||||||
role_key_tool: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
if conversation:
|
if conversation:
|
||||||
if isinstance(conversation, Conversation):
|
if isinstance(conversation, Conversation):
|
||||||
@@ -289,8 +286,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
self.role_key_human = role_key_human
|
self.role_key_human = role_key_human
|
||||||
if role_key_model:
|
if role_key_model:
|
||||||
self.role_key_model = role_key_model
|
self.role_key_model = role_key_model
|
||||||
if role_key_tool:
|
|
||||||
self.role_key_tool = role_key_tool
|
|
||||||
|
|
||||||
def _build_result(self, source):
|
def _build_result(self, source):
|
||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
@@ -308,8 +303,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
source.pop(0)
|
source.pop(0)
|
||||||
|
|
||||||
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
||||||
if self.role_key_tool:
|
|
||||||
roles[self.role_key_tool] = conv.roles[2]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Apply prompt templates
|
# Apply prompt templates
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import torch
|
|||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft import PeftModel
|
from peft import PeftModel, PeftModelForCausalLM
|
||||||
from pkg_resources import get_distribution # type: ignore
|
from pkg_resources import get_distribution # type: ignore
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
@@ -19,7 +19,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
|||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_parameters_except
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
@@ -99,7 +99,7 @@ def train(
|
|||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
if cfg.unfrozen_parameters:
|
if cfg.unfrozen_parameters:
|
||||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
||||||
|
|
||||||
trainer = setup_trainer(
|
trainer = setup_trainer(
|
||||||
cfg,
|
cfg,
|
||||||
@@ -207,6 +207,20 @@ def train(
|
|||||||
|
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
|
if cfg.adapter and isinstance(model, (PeftModel, PeftModelForCausalLM)):
|
||||||
|
model.to("cpu")
|
||||||
|
model = model.merge_and_unload()
|
||||||
|
|
||||||
|
if cfg.local_rank == 0:
|
||||||
|
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||||
|
model.save_pretrained(
|
||||||
|
str(Path(cfg.output_dir) / "merged"),
|
||||||
|
safe_serialization=safe_serialization,
|
||||||
|
progressbar=True,
|
||||||
|
)
|
||||||
|
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
|
|
||||||
|
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
try:
|
try:
|
||||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ def check_cuda_device(default_value):
|
|||||||
or not torch.cuda.is_available()
|
or not torch.cuda.is_available()
|
||||||
or device == "auto"
|
or device == "auto"
|
||||||
or torch.device(device).type == "cpu"
|
or torch.device(device).type == "cpu"
|
||||||
or torch.device(device).type == "meta"
|
|
||||||
):
|
):
|
||||||
return default_value
|
return default_value
|
||||||
|
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Module for pydantic models for configuration
|
Module for pydantic models for configuration
|
||||||
"""
|
"""
|
||||||
# pylint: disable=too-many-lines
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -129,10 +128,8 @@ class RLType(str, Enum):
|
|||||||
class ChatTemplate(str, Enum):
|
class ChatTemplate(str, Enum):
|
||||||
"""Chat templates configuration subset"""
|
"""Chat templates configuration subset"""
|
||||||
|
|
||||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
|
||||||
chatml = "chatml" # pylint: disable=invalid-name
|
chatml = "chatml" # pylint: disable=invalid-name
|
||||||
inst = "inst" # pylint: disable=invalid-name
|
inst = "inst" # pylint: disable=invalid-name
|
||||||
gemma = "gemma" # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
@@ -182,7 +179,6 @@ class LoraConfig(BaseModel):
|
|||||||
peft_layers_to_transform: Optional[List[int]] = None
|
peft_layers_to_transform: Optional[List[int]] = None
|
||||||
peft: Optional[PeftConfig] = None
|
peft: Optional[PeftConfig] = None
|
||||||
peft_use_dora: Optional[bool] = None
|
peft_use_dora: Optional[bool] = None
|
||||||
peft_use_relora: Optional[bool] = None
|
|
||||||
|
|
||||||
lora_on_cpu: Optional[bool] = None
|
lora_on_cpu: Optional[bool] = None
|
||||||
gptq: Optional[bool] = None
|
gptq: Optional[bool] = None
|
||||||
@@ -515,12 +511,10 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
neftune_noise_alpha: Optional[float] = None
|
neftune_noise_alpha: Optional[float] = None
|
||||||
|
|
||||||
max_memory: Optional[
|
max_memory: Optional[Union[int, str]] = None
|
||||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
|
||||||
] = None
|
|
||||||
gpu_memory_limit: Optional[Union[int, str]] = None
|
gpu_memory_limit: Optional[Union[int, str]] = None
|
||||||
|
|
||||||
chat_template: Optional[ChatTemplate] = None
|
chat_template: Optional[Union[Literal["chatml", "inst"], ChatTemplate]] = None
|
||||||
default_system_message: Optional[str] = None
|
default_system_message: Optional[str] = None
|
||||||
|
|
||||||
# INTERNALS - document for now, generally not set externally
|
# INTERNALS - document for now, generally not set externally
|
||||||
@@ -995,10 +989,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_fsdp_deepspeed(cls, data):
|
|
||||||
if data.get("deepspeed") and data.get("fsdp"):
|
|
||||||
raise ValueError("deepspeed and fsdp cannot be used together.")
|
|
||||||
return data
|
|
||||||
|
|||||||
@@ -114,7 +114,9 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||||
if total_eval_steps == 0:
|
if total_eval_steps == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
|
"eval dataset split is too small for sample_packing. "
|
||||||
|
"You should set `eval_sample_packing: False` "
|
||||||
|
"or decrease the value of `eval_batch_size`. "
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.max_steps:
|
if cfg.max_steps:
|
||||||
|
|||||||
@@ -3,14 +3,13 @@ module to freeze/unfreeze parameters by name
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Callable, List, Tuple
|
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.freeze")
|
LOG = logging.getLogger("axolotl.utils.freeze")
|
||||||
|
|
||||||
|
|
||||||
def freeze_layers_except(model, regex_patterns):
|
def freeze_parameters_except(model, regex_patterns):
|
||||||
"""
|
"""
|
||||||
Freezes all layers of the given model except for the layers that match given regex patterns.
|
Freezes all layers of the given model except for the layers that match given regex patterns.
|
||||||
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
||||||
@@ -18,209 +17,22 @@ def freeze_layers_except(model, regex_patterns):
|
|||||||
Parameters:
|
Parameters:
|
||||||
- model (nn.Module): The PyTorch model to be modified.
|
- model (nn.Module): The PyTorch model to be modified.
|
||||||
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
||||||
Note that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.
|
|
||||||
Also, to match the entire layer name, the pattern should start with "^" and end with "$", otherwise it will match any part of the layer name.
|
|
||||||
The range pattern part is optional and it is not compiled as a regex pattern which means you must put "$" before the range pattern if you want to match the entire layer name.
|
|
||||||
E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None; the model is modified in place.
|
None; the model is modified in place.
|
||||||
"""
|
"""
|
||||||
if isinstance(regex_patterns, str):
|
# Escape periods and compile the regex patterns
|
||||||
regex_patterns = [regex_patterns]
|
compiled_patterns = [
|
||||||
|
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
|
||||||
|
]
|
||||||
|
|
||||||
patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]
|
# First, freeze all parameters in the model
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
# Unfreeze layers that match the regex patterns
|
# Unfreeze layers that match the regex patterns
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
param.requires_grad = False
|
if any(pattern.match(name) for pattern in compiled_patterns):
|
||||||
unfrozen_ranges = []
|
if is_main_process():
|
||||||
for pattern in patterns:
|
LOG.debug(f"unfreezing {name}")
|
||||||
if not pattern.match(name):
|
|
||||||
continue
|
|
||||||
|
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
|
|
||||||
if pattern.range is not None:
|
|
||||||
unfrozen_ranges.append(pattern.range)
|
|
||||||
|
|
||||||
merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))
|
|
||||||
|
|
||||||
if param.requires_grad and is_main_process():
|
|
||||||
unfrozen_ranges = (
|
|
||||||
f" with ranges {merged_unfrozen_ranges}"
|
|
||||||
if merged_unfrozen_ranges
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
LOG.debug(f"Unfrozen {name}{unfrozen_ranges}")
|
|
||||||
|
|
||||||
if not merged_unfrozen_ranges:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# The range list we need is actually the inverted of the merged ranges
|
|
||||||
ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))
|
|
||||||
|
|
||||||
param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))
|
|
||||||
|
|
||||||
if is_main_process() and all(
|
|
||||||
not param.requires_grad for param in model.parameters()
|
|
||||||
):
|
|
||||||
LOG.warning("All parameters are frozen. Model will not be trained.")
|
|
||||||
|
|
||||||
|
|
||||||
def _invert_ranges(
|
|
||||||
given_ranges: List[Tuple[int, int]], layer_size: int
|
|
||||||
) -> List[Tuple[int, int]]:
|
|
||||||
"""
|
|
||||||
Inverts a list of ranges to obtain the ranges not covered by the given ranges.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
|
|
||||||
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
|
|
||||||
Returns:
|
|
||||||
- List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
|
|
||||||
"""
|
|
||||||
if not given_ranges:
|
|
||||||
return [(0, layer_size)]
|
|
||||||
|
|
||||||
inverted_ranges = []
|
|
||||||
current_start = 0
|
|
||||||
|
|
||||||
for start, end in sorted(given_ranges):
|
|
||||||
if start > current_start:
|
|
||||||
inverted_ranges.append((current_start, start))
|
|
||||||
current_start = max(current_start, end)
|
|
||||||
|
|
||||||
# Handle the case where the last given range does not reach the end of the total_size
|
|
||||||
if current_start < layer_size:
|
|
||||||
inverted_ranges.append((current_start, layer_size))
|
|
||||||
|
|
||||||
return inverted_ranges
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_ranges(
|
|
||||||
given_ranges: List[Tuple[int, int | None]], layer_size: int
|
|
||||||
) -> List[Tuple[int, int]]:
|
|
||||||
"""
|
|
||||||
Merges overlapping ranges and sorts the given ranges.
|
|
||||||
|
|
||||||
This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
|
|
||||||
as tuples, where the first element is the start index (inclusive) and the second element is the end
|
|
||||||
index (exclusive). The end index can be None, indicating that the range extends to the end of the
|
|
||||||
sequence.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
|
|
||||||
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
|
|
||||||
"""
|
|
||||||
# End of each range can be determined now since we have the total size
|
|
||||||
processed_ranges = [
|
|
||||||
(start, end if end is not None else layer_size) for start, end in given_ranges
|
|
||||||
]
|
|
||||||
|
|
||||||
# No need to merge if there's only one or no ranges
|
|
||||||
if len(processed_ranges) <= 1:
|
|
||||||
return processed_ranges
|
|
||||||
|
|
||||||
sorted_ranges = sorted(processed_ranges)
|
|
||||||
|
|
||||||
merged_ranges = [sorted_ranges[0]]
|
|
||||||
for start, end in sorted_ranges[1:]:
|
|
||||||
prev_start, prev_end = merged_ranges[-1]
|
|
||||||
if start <= prev_end:
|
|
||||||
merged_ranges[-1] = (prev_start, max(prev_end, end))
|
|
||||||
else:
|
|
||||||
merged_ranges.append((start, end))
|
|
||||||
|
|
||||||
return merged_ranges
|
|
||||||
|
|
||||||
|
|
||||||
def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
|
|
||||||
"""
|
|
||||||
Create a hook to freeze parameters in specified ranges by setting their gradients to zero.
|
|
||||||
|
|
||||||
This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
|
|
||||||
two integers representing the start and end indices of the range.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- Callable: A hook function to be used with `register_hook` on parameters.
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
```
|
|
||||||
ranges_to_freeze = [(0, 10), (20, 30)]
|
|
||||||
hook = _create_freeze_parameters_hook(ranges_to_freeze)
|
|
||||||
model.register_hook(hook)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def freeze_parameters_hook(gradients):
|
|
||||||
for start, end in ranges_to_freeze:
|
|
||||||
gradients[start:end].zero_()
|
|
||||||
|
|
||||||
return freeze_parameters_hook
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNamePattern:
|
|
||||||
"""
|
|
||||||
Represents a regex pattern for layer names, potentially including a parameter index range.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, pattern: str):
|
|
||||||
"""
|
|
||||||
Initializes a new instance of the LayerNamePattern class.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- pattern (str): The regex pattern for layer names, potentially including a parameter index range.
|
|
||||||
"""
|
|
||||||
self.raw_pattern = pattern
|
|
||||||
name_pattern, self.range = self._parse_pattern(pattern)
|
|
||||||
self.name_regex = re.compile(name_pattern.replace(".", "\\."))
|
|
||||||
|
|
||||||
def match(self, name: str) -> bool:
|
|
||||||
"""
|
|
||||||
Checks if the given layer name matches the regex pattern.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- name (str): The layer name to check.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- bool: True if the layer name matches the pattern, False otherwise.
|
|
||||||
"""
|
|
||||||
return self.name_regex.match(name) is not None
|
|
||||||
|
|
||||||
def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
|
|
||||||
"""
|
|
||||||
Extracts the range pattern from the given pattern.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- pattern (str): The pattern to extract the range from.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
|
|
||||||
"""
|
|
||||||
match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
|
|
||||||
if not match:
|
|
||||||
return pattern, None
|
|
||||||
|
|
||||||
base_pattern, start_part, end_part = match.groups()
|
|
||||||
|
|
||||||
if end_part is None and start_part.isdecimal():
|
|
||||||
index = int(start_part)
|
|
||||||
return base_pattern, (index, index + 1)
|
|
||||||
|
|
||||||
# [:end] or [start:] or [start:end]
|
|
||||||
start = int(start_part) if start_part else 0
|
|
||||||
end = int(end_part) if end_part else None
|
|
||||||
|
|
||||||
if end is not None and start >= end:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid range in layer name pattern: {pattern}."
|
|
||||||
"End of range must be greater than start."
|
|
||||||
)
|
|
||||||
return base_pattern, (start, end)
|
|
||||||
|
|||||||
@@ -1,20 +1,13 @@
|
|||||||
"""Module for models and model loading"""
|
"""Module for models and model loading"""
|
||||||
# pylint: disable=too-many-lines
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import types
|
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union # noqa: F401
|
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import safetensors
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import init_empty_weights
|
|
||||||
from bitsandbytes.nn import Linear4bit, Params4bit
|
|
||||||
from fastcore.parallel import parallel
|
|
||||||
from peft import (
|
from peft import (
|
||||||
LoftQConfig,
|
LoftQConfig,
|
||||||
PeftConfig,
|
PeftConfig,
|
||||||
@@ -23,7 +16,6 @@ from peft import (
|
|||||||
prepare_model_for_kbit_training,
|
prepare_model_for_kbit_training,
|
||||||
)
|
)
|
||||||
from peft.tuners.lora import QuantLinear
|
from peft.tuners.lora import QuantLinear
|
||||||
from torch import Tensor, nn
|
|
||||||
from transformers import ( # noqa: F401
|
from transformers import ( # noqa: F401
|
||||||
AddedToken,
|
AddedToken,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
@@ -35,9 +27,7 @@ from transformers import ( # noqa: F401
|
|||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
|
|
||||||
|
|
||||||
from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
|
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
@@ -272,117 +262,6 @@ def load_tokenizer(cfg):
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
def replace_linear(
|
|
||||||
model: nn.Module,
|
|
||||||
linear_replacement: Type[nn.Module],
|
|
||||||
quant_config: Union[dict, None] = None,
|
|
||||||
skip_modules=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Replace linear modules with a new Linear module.
|
|
||||||
Parameters:
|
|
||||||
model (`torch.nn.Module`):
|
|
||||||
Input model or `torch.nn.Module` as the function is run recursively.
|
|
||||||
linear_replacement (`torch.nn.Module`):
|
|
||||||
The linear module that replaces the old one. Only expects standard arguments.
|
|
||||||
If other arguments need to be passed, use a lambda.
|
|
||||||
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
|
|
||||||
List of modules names not to convert. Defaults to `lm_head`.
|
|
||||||
"""
|
|
||||||
if skip_modules is None:
|
|
||||||
skip_modules = ["lm_head"]
|
|
||||||
for name, module in model.named_children():
|
|
||||||
if len(list(module.children())) > 0:
|
|
||||||
replace_linear(
|
|
||||||
module, linear_replacement, quant_config, skip_modules, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
|
|
||||||
if issubclass(linear_replacement, Linear4bit):
|
|
||||||
model._modules[ # pylint: disable=protected-access
|
|
||||||
name
|
|
||||||
] = linear_replacement(
|
|
||||||
module.in_features,
|
|
||||||
module.out_features,
|
|
||||||
module.bias is not None,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported linear replacement: {type(linear_replacement)}"
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_and_quantize(
|
|
||||||
module: nn.Module,
|
|
||||||
name: str,
|
|
||||||
value: Tensor,
|
|
||||||
device: torch.device = None,
|
|
||||||
dtype: torch.dtype = None,
|
|
||||||
skip_names: Optional[List[str]] = None,
|
|
||||||
is_meta_rank: bool = False,
|
|
||||||
low_memory: bool = True,
|
|
||||||
verbose: bool = False,
|
|
||||||
quant_method: str = "bnb",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
|
|
||||||
|
|
||||||
Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if skip_names is None:
|
|
||||||
skip_names = []
|
|
||||||
|
|
||||||
def place_on_device(value):
|
|
||||||
if is_meta_rank:
|
|
||||||
device = "meta"
|
|
||||||
elif low_memory:
|
|
||||||
device = "cpu"
|
|
||||||
else:
|
|
||||||
device = "cuda"
|
|
||||||
return value.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if any(skip_name in name for skip_name in skip_names):
|
|
||||||
if verbose:
|
|
||||||
print(f"Skipping {name} because it is in skip_names")
|
|
||||||
return
|
|
||||||
|
|
||||||
module_key, _, value_key = name.rpartition(".")
|
|
||||||
try:
|
|
||||||
submodule = module.get_submodule(module_key)
|
|
||||||
except AttributeError as exc:
|
|
||||||
print(f"Module {module_key} not found:\n{exc}")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
if quant_method == "bnb":
|
|
||||||
param = submodule.get_parameter(value_key)
|
|
||||||
if isinstance(param, Params4bit):
|
|
||||||
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
|
|
||||||
# shape as the quantized Params4bit with an initialized quant_state. However,
|
|
||||||
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
|
|
||||||
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
|
|
||||||
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
|
|
||||||
value = type(param)(
|
|
||||||
value.to(device=device, dtype=dtype).data, **param.__dict__
|
|
||||||
).cuda(device)
|
|
||||||
if is_meta_rank:
|
|
||||||
value = type(param)(value.data.to("meta"), **value.__dict__)
|
|
||||||
elif low_memory:
|
|
||||||
value = type(param)(value.data.to("cpu"), **value.__dict__)
|
|
||||||
else:
|
|
||||||
value = type(param)(place_on_device(value).data)
|
|
||||||
|
|
||||||
except AttributeError:
|
|
||||||
# it's a buffer
|
|
||||||
value = place_on_device(value)
|
|
||||||
|
|
||||||
setattr(submodule, value_key, value)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
@@ -429,7 +308,7 @@ def load_model(
|
|||||||
and cfg.flash_attention
|
and cfg.flash_attention
|
||||||
and cfg.sample_packing
|
and cfg.sample_packing
|
||||||
):
|
):
|
||||||
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
|
patch_for_multipack(cfg.model_config_type)
|
||||||
elif cfg.is_llama_derived_model:
|
elif cfg.is_llama_derived_model:
|
||||||
# Modify all llama derived models in one block
|
# Modify all llama derived models in one block
|
||||||
|
|
||||||
@@ -515,7 +394,7 @@ def load_model(
|
|||||||
|
|
||||||
if max_memory is not None:
|
if max_memory is not None:
|
||||||
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
|
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
|
||||||
from accelerate import infer_auto_device_map
|
from accelerate import infer_auto_device_map, init_empty_weights
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model_canvas = AutoModelForCausalLM.from_config(model_config)
|
model_canvas = AutoModelForCausalLM.from_config(model_config)
|
||||||
@@ -617,78 +496,8 @@ def load_model(
|
|||||||
model_kwargs["attn_implementation"] = "eager"
|
model_kwargs["attn_implementation"] = "eager"
|
||||||
model_config._attn_implementation = "eager" # pylint: disable=protected-access
|
model_config._attn_implementation = "eager" # pylint: disable=protected-access
|
||||||
|
|
||||||
qlora_fsdp = (
|
|
||||||
cfg.fsdp
|
|
||||||
and cfg.adapter == "qlora"
|
|
||||||
and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if qlora_fsdp:
|
if (
|
||||||
if cfg.bf16 or cfg.bfloat16:
|
|
||||||
torch_dtype, compute_dtype = torch.float32, torch.bfloat16
|
|
||||||
elif cfg.fp16 or cfg.float16:
|
|
||||||
torch_dtype, compute_dtype = torch.float32, torch.float16
|
|
||||||
else:
|
|
||||||
torch_dtype, compute_dtype = torch.float32, torch.float16
|
|
||||||
|
|
||||||
with init_empty_weights():
|
|
||||||
LOG.info("Loading model with empty weights.")
|
|
||||||
model = AutoModelForCausalLM.from_config(model_config)
|
|
||||||
model.model = replace_linear(
|
|
||||||
model.model,
|
|
||||||
Linear4bit,
|
|
||||||
compute_dtype=compute_dtype,
|
|
||||||
quant_type="nf4",
|
|
||||||
quant_storage=torch_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
model.is_loaded_in_4bit = True
|
|
||||||
|
|
||||||
# Grab the safetensors files that hold the weights
|
|
||||||
try:
|
|
||||||
idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
|
|
||||||
files, _ = hub.get_checkpoint_shard_files(base_model, idx)
|
|
||||||
except OSError:
|
|
||||||
try:
|
|
||||||
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
|
|
||||||
files = []
|
|
||||||
files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
|
|
||||||
except OSError as exc:
|
|
||||||
# This means the model probably doesn't have a safetensors file
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
|
|
||||||
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
|
|
||||||
def load_and_quantize_parallel(name_param, model, **kwargs):
|
|
||||||
name, param = name_param
|
|
||||||
load_and_quantize(model, name, param, **kwargs)
|
|
||||||
|
|
||||||
param_count = sum((p.numel() for n, p in model.named_parameters()))
|
|
||||||
for filename in files:
|
|
||||||
weights = safetensors.torch.load_file(filename)
|
|
||||||
quant_method = "bnb"
|
|
||||||
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
|
|
||||||
left = int(os.cpu_count() / torch.cuda.device_count())
|
|
||||||
right = int(
|
|
||||||
8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
|
|
||||||
)
|
|
||||||
n_workers = min(left, right)
|
|
||||||
parallel(
|
|
||||||
load_and_quantize_parallel,
|
|
||||||
weights.items(),
|
|
||||||
n_workers=n_workers,
|
|
||||||
threadpool=True,
|
|
||||||
model=model,
|
|
||||||
dtype=torch_dtype,
|
|
||||||
device=cfg.local_rank,
|
|
||||||
skip_names=[],
|
|
||||||
is_meta_rank=(cfg.local_rank != 0),
|
|
||||||
verbose=False,
|
|
||||||
quant_method=quant_method,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif (
|
|
||||||
model_config.model_type == "llama"
|
model_config.model_type == "llama"
|
||||||
and not cfg.trust_remote_code
|
and not cfg.trust_remote_code
|
||||||
and not cfg.gptq
|
and not cfg.gptq
|
||||||
@@ -804,7 +613,7 @@ def load_model(
|
|||||||
LOG.exception(err)
|
LOG.exception(err)
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
|
if isinstance(model, (PeftModel, PeftModelForCausalLM)):
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
|
|
||||||
embeddings_len = (
|
embeddings_len = (
|
||||||
@@ -883,9 +692,6 @@ def load_model(
|
|||||||
if cfg.adapter == "lora" and loftq_bits:
|
if cfg.adapter == "lora" and loftq_bits:
|
||||||
skip_prepare_model_for_kbit_training = True
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if qlora_fsdp:
|
|
||||||
skip_prepare_model_for_kbit_training = True
|
|
||||||
|
|
||||||
if cfg.adapter in ["lora", "qlora"]:
|
if cfg.adapter in ["lora", "qlora"]:
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
@@ -900,7 +706,7 @@ def load_model(
|
|||||||
|
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
|
if needs_fa2_dtype or cfg.flash_attention:
|
||||||
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
|
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if "norm" in name:
|
if "norm" in name:
|
||||||
@@ -918,12 +724,7 @@ def load_model(
|
|||||||
else:
|
else:
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|
||||||
if (
|
if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
|
||||||
cfg.ddp
|
|
||||||
and not load_in_8bit
|
|
||||||
and not (cfg.rl and cfg.load_in_4bit)
|
|
||||||
and not qlora_fsdp
|
|
||||||
):
|
|
||||||
# TODO revaldate this conditional
|
# TODO revaldate this conditional
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
|
|
||||||
@@ -1012,30 +813,6 @@ def find_all_linear_names(model):
|
|||||||
return list(lora_module_names)
|
return list(lora_module_names)
|
||||||
|
|
||||||
|
|
||||||
def setup_quantized_meta_for_peft(model: nn.Module):
|
|
||||||
"""Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
|
|
||||||
|
|
||||||
def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument
|
|
||||||
return self
|
|
||||||
|
|
||||||
for param in model.parameters():
|
|
||||||
if isinstance(param, Params4bit):
|
|
||||||
param.quant_state._orig_to = ( # pylint: disable=protected-access
|
|
||||||
param.quant_state.to
|
|
||||||
)
|
|
||||||
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_quantized_peft_meta_for_training(model: nn.Module):
|
|
||||||
"""Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
|
|
||||||
for param in model.parameters():
|
|
||||||
if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
|
|
||||||
param.quant_state.to = (
|
|
||||||
param.quant_state._orig_to # pylint: disable=protected-access
|
|
||||||
)
|
|
||||||
param.quant_state._orig_to = None # pylint: disable=protected-access
|
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model, cfg, inference=False, config_only=False):
|
def load_lora(model, cfg, inference=False, config_only=False):
|
||||||
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
|
||||||
|
|
||||||
@@ -1055,8 +832,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
lora_config_kwargs["init_lora_weights"] = "loftq"
|
lora_config_kwargs["init_lora_weights"] = "loftq"
|
||||||
if cfg.peft_use_dora:
|
if cfg.peft_use_dora:
|
||||||
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
||||||
if cfg.peft_use_rslora:
|
|
||||||
lora_config_kwargs["use_rslora"] = cfg.use_rslora
|
|
||||||
|
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
r=cfg.lora_r,
|
r=cfg.lora_r,
|
||||||
@@ -1074,11 +849,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
if config_only:
|
if config_only:
|
||||||
return None, lora_config
|
return None, lora_config
|
||||||
|
|
||||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
||||||
|
|
||||||
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
|
|
||||||
setup_quantized_meta_for_peft(model)
|
|
||||||
|
|
||||||
if cfg.lora_model_dir:
|
if cfg.lora_model_dir:
|
||||||
LOG.debug("Loading pretrained PEFT - LoRA")
|
LOG.debug("Loading pretrained PEFT - LoRA")
|
||||||
model_kwargs: Any = {}
|
model_kwargs: Any = {}
|
||||||
@@ -1094,9 +864,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
else:
|
else:
|
||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
if rank == 0:
|
model.print_trainable_parameters()
|
||||||
model.print_trainable_parameters()
|
|
||||||
elif cfg.fsdp and cfg.adapter == "qlora":
|
|
||||||
setup_quantized_peft_meta_for_training(model)
|
|
||||||
|
|
||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ Multipack Batch Sampler
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import Any, Iterable, List, Union
|
from typing import Any, Iterable, List, Union, Optional
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -115,12 +115,14 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
batch_max_len: int,
|
batch_max_len: int,
|
||||||
lengths: np.ndarray,
|
lengths: np.ndarray,
|
||||||
packing_efficiency_estimate: float = 1.0,
|
packing_efficiency_estimate: float = 1.0,
|
||||||
|
consistent_length: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
super().__init__(sampler, batch_size, drop_last)
|
super().__init__(sampler, batch_size, drop_last)
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.batch_max_len = batch_max_len
|
self.batch_max_len = batch_max_len
|
||||||
self.lengths: np.ndarray = lengths
|
self.lengths: np.ndarray = lengths
|
||||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||||
|
self.consistent_length = consistent_length
|
||||||
|
|
||||||
assert isinstance(self.lengths, np.ndarray)
|
assert isinstance(self.lengths, np.ndarray)
|
||||||
|
|
||||||
@@ -164,11 +166,18 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
batches = self.generate_batches(set_stats=True)
|
batches = self.generate_batches(set_stats=True)
|
||||||
return iter(batches)
|
if self.consistent_length:
|
||||||
|
length = self._len_est()
|
||||||
|
return iter(batches[:length])
|
||||||
|
else:
|
||||||
|
return iter(batches)
|
||||||
|
|
||||||
def num_batches(self):
|
def num_batches(self):
|
||||||
batches = self.generate_batches(set_stats=True)
|
batches = self.generate_batches(set_stats=True)
|
||||||
return len(batches)
|
if self.consistent_length:
|
||||||
|
return self._len_est()
|
||||||
|
else:
|
||||||
|
return len(batches)
|
||||||
|
|
||||||
def efficiency(self):
|
def efficiency(self):
|
||||||
return self.eff_total_used / self.eff_total_slots
|
return self.eff_total_used / self.eff_total_slots
|
||||||
|
|||||||
@@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
@@ -38,65 +36,3 @@ def check_example_labels(example, tokenizer, text_only=False):
|
|||||||
LOG.info("\n\n\n")
|
LOG.info("\n\n\n")
|
||||||
|
|
||||||
return " ".join(colored_tokens)
|
return " ".join(colored_tokens)
|
||||||
|
|
||||||
|
|
||||||
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
|
||||||
GLAIVE_TO_SHAREGPT_ROLE = {
|
|
||||||
"SYSTEM": "system",
|
|
||||||
"USER": "human",
|
|
||||||
"ASSISTANT": "gpt",
|
|
||||||
"FUNCTION RESPONSE": "tool",
|
|
||||||
}
|
|
||||||
|
|
||||||
GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")
|
|
||||||
|
|
||||||
|
|
||||||
def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
|
|
||||||
"""
|
|
||||||
Converts a ChatML formatted row to a list of messages in ShareGPT format.
|
|
||||||
Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.
|
|
||||||
"""
|
|
||||||
|
|
||||||
system_prompt = row.get("system")
|
|
||||||
if system_prompt:
|
|
||||||
system_prompt = system_prompt.removeprefix("SYSTEM: ")
|
|
||||||
|
|
||||||
chat_str = row["chat"]
|
|
||||||
chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s]
|
|
||||||
|
|
||||||
chat_msg_dicts = [
|
|
||||||
{"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value}
|
|
||||||
for role, value in zip(chat_msgs[::2], chat_msgs[1::2])
|
|
||||||
]
|
|
||||||
|
|
||||||
if system_prompt:
|
|
||||||
chat_msg_dicts = [
|
|
||||||
{"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
|
|
||||||
] + chat_msg_dicts
|
|
||||||
|
|
||||||
return chat_msg_dicts
|
|
||||||
|
|
||||||
|
|
||||||
def merge_consecutive_messages(messages):
|
|
||||||
"""
|
|
||||||
Merge consecutive messages from the same sender into a single message.
|
|
||||||
This can be useful with datasets that contain multiple consecutive tool calls.
|
|
||||||
"""
|
|
||||||
|
|
||||||
merged_messages = []
|
|
||||||
current_from = None
|
|
||||||
current_message = ""
|
|
||||||
|
|
||||||
for msg in messages:
|
|
||||||
if current_from == msg["from"]:
|
|
||||||
current_message += msg["value"]
|
|
||||||
else:
|
|
||||||
if current_from is not None:
|
|
||||||
merged_messages.append({"from": current_from, "value": current_message})
|
|
||||||
current_from = msg["from"]
|
|
||||||
current_message = msg["value"]
|
|
||||||
|
|
||||||
if current_from is not None:
|
|
||||||
merged_messages.append({"from": current_from, "value": current_message})
|
|
||||||
|
|
||||||
return merged_messages
|
|
||||||
|
|||||||
@@ -277,7 +277,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
calc_sample_packing_eff_est,
|
calc_sample_packing_eff_est,
|
||||||
)
|
)
|
||||||
sample_packing_eff_est = (
|
sample_packing_eff_est = (
|
||||||
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
math.ceil(sample_packing_actual_eff_all * 10000.0) / 10000.0
|
||||||
)
|
)
|
||||||
if update:
|
if update:
|
||||||
cfg.sample_packing_eff_est = sample_packing_eff_est
|
cfg.sample_packing_eff_est = sample_packing_eff_est
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Test module for sharegpt integration w chatml
|
Test module for sharegpt integration w chatml
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from tokenizers import AddedToken
|
from tokenizers import AddedToken
|
||||||
@@ -9,7 +8,6 @@ from transformers import AutoTokenizer
|
|||||||
|
|
||||||
from axolotl.datasets import TokenizedPromptDataset
|
from axolotl.datasets import TokenizedPromptDataset
|
||||||
from axolotl.prompt_strategies.sharegpt import (
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
GlaiveShareGPTPromptTokenizingStrategy,
|
|
||||||
SimpleShareGPTPromptTokenizingStrategy,
|
SimpleShareGPTPromptTokenizingStrategy,
|
||||||
register_chatml_template,
|
register_chatml_template,
|
||||||
)
|
)
|
||||||
@@ -50,18 +48,6 @@ def fixture_sharegpt_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="glaive_dataset")
|
|
||||||
def fixture_sharegpt_glaive_dataset():
|
|
||||||
return Dataset.from_list(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"system": "SYSTEM: This is a system prompt",
|
|
||||||
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="tokenizer")
|
@pytest.fixture(name="tokenizer")
|
||||||
def fixture_tokenizer():
|
def fixture_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||||
@@ -170,29 +156,3 @@ class TestSharegpt:
|
|||||||
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
|
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def test_chatml_glaive(self, glaive_dataset, tokenizer):
|
|
||||||
strategy = GlaiveShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(
|
|
||||||
conversation="chatml",
|
|
||||||
role_key_model=None,
|
|
||||||
role_key_human=None,
|
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
True, # train_on_inputs
|
|
||||||
2048, # sequence_len
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
|
||||||
strategy, glaive_dataset, process_count=1
|
|
||||||
)
|
|
||||||
|
|
||||||
labels = dataset_wrapper[0]["labels"]
|
|
||||||
# fmt: off
|
|
||||||
assert labels == [
|
|
||||||
1, # bos
|
|
||||||
32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13, # system
|
|
||||||
32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13, # human
|
|
||||||
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
|
|
||||||
]
|
|
||||||
# fmt: on
|
|
||||||
|
|||||||
@@ -1,285 +0,0 @@
|
|||||||
"""
|
|
||||||
This module contains unit tests for the `freeze_layers_except` function.
|
|
||||||
|
|
||||||
The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.
|
|
||||||
The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
|
||||||
|
|
||||||
ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
|
||||||
ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
|
||||||
|
|
||||||
|
|
||||||
class TestFreezeLayersExcept(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
A test case class for the `freeze_layers_except` function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.model = _TestModel()
|
|
||||||
|
|
||||||
def test_freeze_layers_with_dots_in_name(self):
|
|
||||||
freeze_layers_except(self.model, ["features.layer"])
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be trainable.",
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be frozen.",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_freeze_layers_without_dots_in_name(self):
|
|
||||||
freeze_layers_except(self.model, ["classifier"])
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be trainable.",
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be frozen.",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_freeze_layers_regex_patterns(self):
|
|
||||||
# The second pattern cannot match because only characters 'a' to 'c' are allowed after the word 'class', whereas it should be matching the character 'i'.
|
|
||||||
freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"])
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be trainable.",
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be frozen.",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_all_layers_frozen(self):
|
|
||||||
freeze_layers_except(self.model, [])
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be frozen.",
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be frozen.",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_all_layers_unfrozen(self):
|
|
||||||
freeze_layers_except(self.model, ["features.layer", "classifier"])
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be trainable.",
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be trainable.",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_freeze_layers_with_range_pattern_start_end(self):
|
|
||||||
freeze_layers_except(self.model, ["features.layer[1:5]"])
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be trainable.",
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be frozen.",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._assert_gradient_output(
|
|
||||||
[
|
|
||||||
ZERO,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_freeze_layers_with_range_pattern_single_index(self):
|
|
||||||
freeze_layers_except(self.model, ["features.layer[5]"])
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be trainable.",
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be frozen.",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._assert_gradient_output(
|
|
||||||
[ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_freeze_layers_with_range_pattern_start_omitted(self):
|
|
||||||
freeze_layers_except(self.model, ["features.layer[:5]"])
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be trainable.",
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be frozen.",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._assert_gradient_output(
|
|
||||||
[
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_freeze_layers_with_range_pattern_end_omitted(self):
|
|
||||||
freeze_layers_except(self.model, ["features.layer[4:]"])
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be trainable.",
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be frozen.",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._assert_gradient_output(
|
|
||||||
[
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_freeze_layers_with_range_pattern_merge_included(self):
|
|
||||||
freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"])
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be trainable.",
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be frozen.",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._assert_gradient_output(
|
|
||||||
[
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_freeze_layers_with_range_pattern_merge_intersect(self):
|
|
||||||
freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"])
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be trainable.",
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be frozen.",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._assert_gradient_output(
|
|
||||||
[
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_freeze_layers_with_range_pattern_merge_separate(self):
|
|
||||||
freeze_layers_except(
|
|
||||||
self.model,
|
|
||||||
["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"],
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
self.model.features.layer.weight.requires_grad,
|
|
||||||
"model.features.layer should be trainable.",
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
self.model.classifier.weight.requires_grad,
|
|
||||||
"model.classifier should be frozen.",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._assert_gradient_output(
|
|
||||||
[
|
|
||||||
ZERO,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ZERO,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ZERO,
|
|
||||||
ONE_TO_TEN,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
ZERO,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _assert_gradient_output(self, expected):
|
|
||||||
input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)
|
|
||||||
|
|
||||||
self.model.features.layer.weight.grad = None # Reset gradients
|
|
||||||
output = self.model.features.layer(input_tensor)
|
|
||||||
loss = output.sum()
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
expected_grads = torch.tensor(expected)
|
|
||||||
torch.testing.assert_close(
|
|
||||||
self.model.features.layer.weight.grad, expected_grads
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _SubLayerModule(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.layer = nn.Linear(10, 10)
|
|
||||||
|
|
||||||
|
|
||||||
class _TestModel(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.features = _SubLayerModule()
|
|
||||||
self.classifier = nn.Linear(10, 2)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
"""Module for testing prompt tokenizers."""
|
"""Module for testing prompt tokenizers."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import unittest
|
import unittest
|
||||||
@@ -19,7 +18,6 @@ from axolotl.prompt_strategies.llama2_chat import (
|
|||||||
Llama2ChatPrompter,
|
Llama2ChatPrompter,
|
||||||
LLama2ChatTokenizingStrategy,
|
LLama2ChatTokenizingStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
|
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
ShareGPTPromptTokenizingStrategy,
|
ShareGPTPromptTokenizingStrategy,
|
||||||
@@ -268,23 +266,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
idx = res["input_ids"].index(20255) # assistant token
|
idx = res["input_ids"].index(20255) # assistant token
|
||||||
assert res["labels"][idx] == -100
|
assert res["labels"][idx] == -100
|
||||||
|
|
||||||
def test_glaive_tool_label_ignore(self):
|
|
||||||
conversation = {
|
|
||||||
"system": "SYSTEM: This is a system prompt",
|
|
||||||
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
|
|
||||||
}
|
|
||||||
prompter = ShareGPTPrompterV2()
|
|
||||||
strat = GlaiveShareGPTPromptTokenizingStrategy(
|
|
||||||
prompter,
|
|
||||||
self.tokenizer,
|
|
||||||
False,
|
|
||||||
2048,
|
|
||||||
)
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
res = strat.tokenize_prompt(conversation)
|
|
||||||
idx = res["input_ids"].index(13566) # assistant token
|
|
||||||
assert res["labels"][idx] == -100
|
|
||||||
|
|
||||||
def test_no_sys_prompt(self):
|
def test_no_sys_prompt(self):
|
||||||
"""
|
"""
|
||||||
tests the interface between the user and assistant parts
|
tests the interface between the user and assistant parts
|
||||||
|
|||||||
Reference in New Issue
Block a user