Compare commits

...

26 Commits

Author SHA1 Message Date
Casper Hansen
10328b3429 Simplify creating parameters 2024-03-18 12:32:59 +00:00
Casper Hansen
5bfc470d57 Stop transformers from using all memory 2024-03-18 11:47:47 +00:00
Casper Hansen
04168801c9 Simplify conversion + more debug 2024-03-17 20:21:46 +00:00
Casper
d43a79b7bf device_map auto 2024-03-17 19:52:56 +01:00
Casper
884d81331e Initialize ParallelExperts on device of first expert 2024-03-17 19:51:31 +01:00
Casper
2ea75b4160 temporary: inference validation script 2024-03-17 19:48:52 +01:00
Casper Hansen
035e680631 Update test 2024-03-15 13:58:12 +00:00
Casper Hansen
26fc10df01 Refactor names, bugfixes 2024-03-15 12:39:11 +00:00
Casper Hansen
1bc008e901 Refactor creating FusedExperts 2024-03-15 11:59:56 +00:00
Casper Hansen
3f7ed6a784 Bugfixes, test green 2024-03-15 11:48:46 +00:00
Casper
feea977923 initial implementation, untested 2024-03-15 11:54:36 +01:00
Wing Lian
8df7b888ff beta support for multipack with gemmoe: (#1402) 2024-03-14 15:52:23 -04:00
Sebastian Raschka
6366b0c212 Fix Gemma 7b qlora.yml (#1405) 2024-03-14 15:44:38 -04:00
Seungduk Kim
05bcc9ea56 Train parameters exclusively in specific ranges (#1390)
* Train parameters exclusively in specific ranges

* Fix the style and update docs

* Update yaml example
2024-03-14 11:05:42 -04:00
Chirag Jain
3bd8203c35 Don't disable existing loggers when configuring axolotl logging (#1395) 2024-03-14 11:05:21 -04:00
Hamel Husain
8b12468230 Add QLoRA + FSDP Docs (#1403)
* pre commit

* Update fsdp_qlora.md
2024-03-14 11:04:51 -04:00
Chirag Jain
0976781e15 Update ChatTemplate enum to include alpaca and gemma (#1396) 2024-03-13 11:06:02 -04:00
Wing Lian
8a82d2e0a4 add handling for argilla dpo-mix (#1397) 2024-03-12 17:17:10 -04:00
Wing Lian
4326520829 chore: lint (#1389) 2024-03-10 21:02:55 -04:00
Brian Fitzgerald
b7d8a7dc4d Add Glaive conversation format support (#1365)
* Add Glaive conversation format support

* fix black formatting errors

* Fix black and pylint formatting errors

* only set role_key_tool if provided in the dataset constructor

* Update src/axolotl/prompt_strategies/sharegpt.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* sharegpt test

* tokenizer test

* fix formatting

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-03-10 20:50:25 -04:00
Seungduk Kim
b0ee9ec734 Set gradient_clipping to auto in DeepSpeed configs (#1382) [skip ci] 2024-03-10 20:50:12 -04:00
David Baker
0bc114d2e1 Fix pydantic configuration for the max_memory input (#1385) [skip ci]
* Fix pydantic configuration for the max_memory input

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-03-10 20:50:04 -04:00
Wing Lian
7659c001aa support for rslora (#1387) [skip ci] 2024-03-10 20:49:45 -04:00
Wing Lian
3fd8093717 validation for fsdp and deepspeed (#1388) [skip ci]
* validation for fsdp and deepspeed

* make sure to return data
2024-03-10 20:49:25 -04:00
Wing Lian
9b6ee83a73 FDSP + QLoRA (#1378)
* wip qlora + fsdp fixes

* more fixes

* make sure to load the lora 🤦

* only setup quantized meta on non-zero rank:

* only run setup_quantized_peft_meta_for_training for qlora+fsdp

* more fixes for qlora+fsdp

* chore: lint

* add example yml

* support mistral too

* fix for model_type and add mixtral support too

* set cpu_offload: false to reduce vram, constrain new accleerator logic to qlora + fsdp

* refactor for duplicate code
2024-03-08 14:31:01 -05:00
Wing Lian
638c2dafb5 JarvisLabs (#1372)
* add Jarvis cloud gpu and sponsorship

* whitespace
2024-03-07 10:47:32 -05:00
37 changed files with 2143 additions and 63 deletions

View File

@@ -25,7 +25,7 @@ Features:
- [Environment](#environment)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [Windows](#windows)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
@@ -199,6 +199,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### Bare Metal Cloud GPU
@@ -1079,6 +1080,10 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
##### FSDP + QLoRA
Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.md) for more information.
##### Weights & Biases Logging
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
@@ -1298,4 +1303,6 @@ consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsor
#### 🥉 Bronze Sponsors - $500/mo
- [JarvisLabs.ai](https://jarvislabs.ai)
---

View File

@@ -16,6 +16,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -20,6 +20,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -24,6 +24,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -24,6 +24,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

docs/fsdp_qlora.md Normal file
View File

@@ -0,0 +1,37 @@
# FSDP + QLoRA
## Background
Using FSDP with QLoRA is essential for **fine-tuning larger (70b+ parameter) LLMs on consumer GPUs.** For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs[^1].
Below, we describe how to use this feature in Axolotl.
## Usage
To enable `QLoRA` with `FSDP`, you need to perform the following steps:
> [!TIP]
> See the [example config](#example-config) file in addition to reading these instructions.
1. Set `adapter: qlora` in your axolotl config file.
2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Example Config
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
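As a rough sanity check (not part of the upstream docs), the three usage steps above map directly onto keys in that example config; a minimal sketch assuming PyYAML and a repo checkout:

```python
# Sketch: confirm the FSDP + QLoRA keys in the example config (hypothetical check, not an official test).
import yaml

with open("examples/llama-2/qlora-fsdp.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

assert cfg["adapter"] == "qlora"                           # step 1: QLoRA adapter
assert cfg["fsdp"] == ["full_shard"]                       # step 2: FSDP enabled
assert cfg["base_model"] == "NousResearch/Llama-2-7b-hf"   # step 3: a supported (llama) model type
```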
## References
- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
- Related HuggingFace PRs Enabling FSDP + QLoRA:
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544)
- Transformers [PR#29587](https://github.com/huggingface/transformers/pull/29587)
- TRL [PR#1416](https://github.com/huggingface/trl/pull/1416)
- PEFT [PR#1550](https://github.com/huggingface/peft/pull/1550)
[^1]: This was enabled by [this work](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the Answer.AI team.

View File

@@ -21,7 +21,7 @@ lora_dropout: 0.05
lora_target_linear: true
sequence_len: 4096
sample_packing: true
sample_packing: false
pad_to_sequence_len: true
wandb_project:

View File

@@ -0,0 +1,70 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:

View File

@@ -0,0 +1,74 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
special_tokens:

View File

@@ -16,12 +16,12 @@ output_dir: ./qlora-out
## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
# - lm_head.*
# - model.embed_tokens.*
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
# - ^lm_head.weight$
# - ^model.embed_tokens.weight$[:32000]
# - model.layers.2[0-9]+.block_sparse_moe.gate
# - model.layers.2[0-9]+.block_sparse_moe.experts
# - model.layers.3[0-9]+.block_sparse_moe.gate
# - model.layers.3[0-9]+.block_sparse_moe.experts
model_config:
output_router_logits: true

View File

@@ -0,0 +1,75 @@
import gc
import torch
from tqdm import tqdm
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
from transformers import AutoTokenizer, TextStreamer
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralForCausalLM, MixtralConfig
def compute_memory_used_pct(device):
memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
memory_pct = (
memory_used
/ (torch.cuda.get_device_properties(device).total_memory / (1024**3))
* 100
)
return memory_pct
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Load model
config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048, use_cache=False)
model = MixtralForCausalLM.from_pretrained(
model_path,
config=config,
device_map="auto",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
)
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
for device_index in range(torch.cuda.device_count()):
device_memory_pct = compute_memory_used_pct(device_index)
print(device_index, device_memory_pct)
with tqdm(modules.items(), desc="scatter moe") as pbar:
for i, (name, module) in enumerate(pbar):
smoe = SparseMoeBlock(
experts=module.experts,
gate=module.gate,
hidden_dim=module.hidden_dim,
ffn_dim=module.ffn_dim,
num_experts=module.num_experts,
top_k=module.top_k,
)
old_module = model.model.layers[i].block_sparse_moe
setattr(model.model.layers[i], "block_sparse_moe", smoe)
del old_module
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
for device_index in range(torch.cuda.device_count()):
device_memory_pct = compute_memory_used_pct(device_index)
print(device_index, device_memory_pct)
tokenizer = AutoTokenizer.from_pretrained(model_path)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = "[INST] {prompt} [/INST]"
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
tokens = tokenizer(
prompt_template.format(prompt=prompt),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)

View File

@@ -3,7 +3,7 @@ packaging==23.2
peft==0.9.0
transformers==4.38.2
tokenizers==0.15.0
bitsandbytes>=0.41.1
bitsandbytes>=0.43.0
accelerate==0.26.1
deepspeed==0.13.1
pydantic==2.6.3
@@ -40,3 +40,4 @@ gcsfs
# adlfs
trl>=0.7.9
fastcore>=1.5.29

View File

View File

@@ -0,0 +1,55 @@
"""module for building the auto wrap policy for FSDP"""
import functools
from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
_or_policy,
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
]
def get_wrapping_policy_factory(model_type):
if model_type == "llama":
layer_to_wrap = LlamaDecoderLayer
elif model_type == "mistral":
layer_to_wrap = MistralDecoderLayer
elif model_type == "mixtral":
layer_to_wrap = MixtralDecoderLayer
def get_wrapping_policy():
"""This checks for lora layers (has weight and requires_grad)"""
def lambda_policy_fn(module):
return (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
)
lambda_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
)
transformer_layer_name = layer_to_wrap
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
transformer_layer_name,
),
)
policies = [lambda_policy, transformer_wrap_policy]
return functools.partial(_or_policy, policies=policies)
return get_wrapping_policy
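For orientation, a minimal sketch of how this factory is consumed to build an FSDP plugin, mirroring the trainer change later in this diff (`"llama"` is just one of the supported model types):

```python
# Sketch: build an accelerate FSDP plugin from the wrapping-policy factory (assumes accelerate is installed).
from accelerate import FullyShardedDataParallelPlugin

from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory

wrapping_policy = get_wrapping_policy_factory("llama")
fsdp_plugin = FullyShardedDataParallelPlugin(
    auto_wrap_policy=wrapping_policy(),  # wraps decoder layers, prompt encoders, and trainable LoRA leaf modules
    cpu_offload=False,
    use_orig_params=False,
    limit_all_gathers=True,
)
```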

View File

@@ -8,6 +8,7 @@ import importlib
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
@@ -17,7 +18,10 @@ from typing import List, Optional, Type, Union
import torch
import transformers
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import str_to_bool
from datasets import Dataset
from torch.distributed.fsdp import MixedPrecision
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -30,6 +34,7 @@ from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -191,6 +196,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
class AxolotlTrainer(Trainer):
@@ -468,6 +477,56 @@ class AxolotlTrainer(Trainer):
return super().push_to_hub(*args, **kwargs)
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
rank = int(os.environ.get("LOCAL_RANK", 0))
res = super().create_accelerator_and_postprocess()
if self.args.qlora is False:
return res
# the rest of this method override is specific to fsdp + qlora (for now)
sync_module_states = (
str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
)
mp_policy = None
amp = os.environ["ACCELERATE_MIXED_PRECISION"]
if amp == "fp16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
elif amp == "bf16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
# If somehow we figure out how we want to parameterize we want to autocast buffers...
# mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
# load_param_skip_names = ['inv_freq']
if self.is_fsdp_enabled:
wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrapping_policy(),
cpu_offload=False,
use_orig_params=False,
limit_all_gathers=True,
param_init_fn=lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if (rank != 0 and sync_module_states)
else None,
mixed_precision_policy=mp_policy,
)
self.accelerator.state.fsdp_plugin = fsdp_plugin
return res
class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -787,6 +846,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed

View File

@@ -30,6 +30,7 @@ class ColorfulFormatter(Formatter):
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",

View File

View File

@@ -0,0 +1,149 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""
import torch
import torch.nn as nn
from axolotl.monkeypatch.moe import ops
class ParallelLinear(torch.autograd.Function):
@staticmethod
def forward(
ctx, x, expert_weights, k,
sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates=None, grouped_in=False, grouped_out=False,
):
output = ops.scatter2scatter(
X=x, W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
padded_block_idxs=padded_block_idxs,
k=k, x_grouped=grouped_in, y_grouped=grouped_out
)
if gates is not None:
output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
output = torch.bmm(
gates[:, None, :],
output_expanded
).squeeze(1)
else:
output_expanded = None
ctx.save_for_backward(
x, expert_weights,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates,
output_expanded
)
ctx.grouped_in = grouped_in
ctx.grouped_out = grouped_out
ctx.k = k
return output
@staticmethod
def backward(ctx, grad_out):
(x, expert_weights,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates, output_expanded) = ctx.saved_tensors
k = ctx.k
grouped_in = ctx.grouped_in
grouped_out = ctx.grouped_out
# print("backward")
if gates is not None:
# calculate gates gradient
d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
gates_flat = gates.flatten()
gate_fan = gates.size(1)
# print("expanded and grouping")
grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later
else:
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
if grouped_out:
grouped_grad_out = grad_out
else:
grouped_grad_out = ops.group(grad_out, sorted_scattered_idxs,
fan_out=gate_fan, coeff=gates_flat,
out=grouped_grad_out)
if grouped_in:
grouped_x = x
d_expanded_input = None
else:
grouped_x = ops.group(x, sorted_scattered_idxs, fan_out=k)
d_expanded_input = grouped_x
d_weights = ops.group_bwd_W(
DY=grouped_grad_out, X=grouped_x,
expert_offsets=expert_offsets,
E=expert_weights.size(0)
)
d_expanded_input = ops.scatter2scatter(
X=grouped_grad_out, x_grouped=True,
W=expert_weights.permute(0, 2, 1),
padded_block_idxs=padded_block_idxs,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
y_grouped=grouped_in,
out=d_expanded_input # Reuse grouped_x buffer
)
if k == 1:
d_input = d_expanded_input
else:
d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
# print("backward end.")
return (
# x, expert_weights, k,
d_input, d_weights, None,
# sorted_expert_idxs, sorted_scattered_idxs,
None, None,
# padded_block_idxs, expert_offsets,
None, None,
# gates
d_gates, None, None
)
def parallel_linear(inputs, expert_weights, k,
sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates=None):
results = ParallelLinear.apply(inputs, expert_weights, k,
sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets, gates)
return results
class ParallelExperts(nn.Module):
def __init__(self, num_experts, input_size, output_size, device) -> None:
super().__init__()
self.weight = nn.Parameter(
torch.empty(num_experts, output_size, input_size, device=device)
)
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size
def extra_repr(self):
return 'num_experts={}, input_size={}, output_size={}'.format(
self.num_experts, self.input_size, self.output_size)
def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates=None, grouped_in=False, grouped_out=False):
results = ParallelLinear.apply(
inputs, self.weight.permute(0, 2, 1), k,
sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates, grouped_in, grouped_out
)
return results

View File

@@ -0,0 +1,86 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""
import gc
import torch
from torch import nn
from axolotl.monkeypatch.moe import ops
from axolotl.monkeypatch.moe.linear import ParallelExperts
class FusedExperts(nn.Module):
def __init__(
self,
experts: nn.ModuleList =None,
hidden_dim=128,
ffn_dim=512,
num_experts=8,
top_k=2,
activation=nn.SiLU(),
):
"""
This implements fused experts that are compatible with Mixtral.
MLP of type Gated-Linear Unit, typically with a SiLU activation function.
"""
super(FusedExperts, self).__init__()
device = experts[0].w1.weight.device
self.num_experts = num_experts
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, device=device)
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, device=device)
self.top_k = min(top_k, self.num_experts)
self.activation = activation
with torch.no_grad():
for i in range(len(experts)):
self.experts.weight.data[i].copy_(
torch.cat(
[experts[i].w1.weight.detach(), experts[i].w3.weight.detach()],
dim=0
)
)
self.output_experts.weight.data[i].copy_(
experts[i].w2.weight.detach()
)
def forward(
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
):
x_shape = x.size()
x = x.view(-1, x_shape[-1])
with torch.no_grad():
sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(
selected_experts
)
padded_block_idxs, expert_offsets = ops.padded_block_indices(
sorted_expert_idxs, self.num_experts
)
h, gates = self.experts(
x,
self.top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
grouped_out=True,
).chunk(2, dim=-1)
h = self.activation(gates) * h
y = self.output_experts(
h,
1,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
grouped_in=True,
gates=routing_weights,
)
y = y.view(*x_shape[:-1], y.size(-1))
return y

View File

@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from axolotl.monkeypatch.moe.mlp import FusedExperts
class SparseMoeBlock(nn.Module):
def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k):
super().__init__()
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
self.num_experts = num_experts
self.top_k = top_k
self.gate = gate
self.experts = FusedExperts(
experts=experts,
hidden_dim=hidden_dim,
ffn_dim=ffn_dim,
num_experts=num_experts,
top_k=top_k,
activation=experts[0].act_fn
)
def _post_training(self, model, name):
# get original weights back: reverse the concat + stack in the fused experts
w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
# TODO: recreate MoE class with original weights
experts = []
for i in range(self.num_experts):
pass
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
# Fused expert forward
final_hidden_states = self.experts(hidden_states, routing_weights, selected_experts)
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
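The routing math in `forward` above is the standard Mixtral top-k gating; a toy illustration with made-up shapes (no real model involved):

```python
# Toy top-k routing, matching the steps in SparseMoeBlock.forward (random tensors only).
import torch
import torch.nn.functional as F

router_logits = torch.randn(4, 8)  # (batch * seq_len, num_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize over the chosen top-2
print(selected_experts.shape, routing_weights.shape)  # both torch.Size([4, 2])
```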

View File

@@ -0,0 +1,353 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""
import torch
import triton
import triton.language as tl
from torch.nn import functional as F
BLOCK_M = 128
@torch.jit.script
def flatten_and_sort(expert_idxs:torch.Tensor):
flattened_expert_idxs = expert_idxs.flatten()
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
return sorted_expert_idxs, sorted_scattered_idxs
@torch.jit.script
def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int=BLOCK_M) :
expert_counts = torch.bincount(sorted_experts_idxs, minlength=k)
padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1
padded_expert_block_end = padded_block_counts.cumsum(-1)
expert_boundaries_end = expert_counts.cumsum(-1)
expert_boundaries_start = expert_boundaries_end - expert_counts
padded_expert_block_start = padded_expert_block_end - padded_block_counts
block_idxs = torch.arange(padded_expert_block_end[-1],
dtype=sorted_experts_idxs.dtype,
device=sorted_experts_idxs.device)
block_mask = (
(block_idxs[:, None] < padded_expert_block_start) |
(block_idxs[:, None] >= padded_expert_block_end)
)
expanded_block_idxs = (
N_BLOCK_SIZE * (block_idxs[:, None] - padded_expert_block_start) +
expert_boundaries_start
)
expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1)
return expanded_block_idxs, expert_boundaries_end
def _scatter2scatter_configs():
return [
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
]
@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], )
@triton.heuristics({
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
"NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
})
@triton.jit
def _scatter2scatter(
X_ptr, stride_xm, stride_xk,
W_ptr, stride_we, stride_wk, stride_wn,
Y_ptr, stride_ym, stride_yn,
grouped_idx_ptr, expert_idxs_ptr, block_start_idx_ptr,
FAN_OUT: tl.constexpr,
M: tl.constexpr, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
OUT_M: tl.constexpr,
allow_tf32: tl.constexpr,
x_grouped: tl.constexpr, y_grouped: tl.constexpr,
NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
):
pid = tl.program_id(axis=0)
N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
M_block_id = pid // N_BLOCK_COUNT
N_block_id = pid % N_BLOCK_COUNT
M_range = tl.arange(0, BLOCK_M)
block_start_idx = tl.load(block_start_idx_ptr + M_block_id)
# M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M)
M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M)
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E)
E_idx = tl.min(E_idxs)
E_mask = E_idxs == E_idx
M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0)
if x_grouped:
M_in_idx = M_block
else:
M_in_idx = M_idx // FAN_OUT
if y_grouped:
M_out_idx = M_block
else:
M_out_idx = M_idx
K_block = tl.arange(0, BLOCK_K)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
# N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
# N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
iters = tl.cdiv(K, BLOCK_K)
for K_block_id in range(0, iters):
if NO_K_MASK:
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
if NO_N_MASK:
w = tl.load(W_blk_ptrs)
else:
w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
else:
K_mask = (K_block_id * BLOCK_K + K_block) < K
x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE)
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :])
def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
padded_block_idxs, x_grouped=False, y_grouped=False,
out=None):
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
# Pre-kernel setup
x_dim = X.size(-1)
y_dim = W.size(-1)
L_scattered = sorted_expert_idxs.size(0)
if out is None:
O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
else:
assert out.size(0) == L_scattered and out.size(1) == y_dim
O = out
def grid(META):
grid_num = (
padded_block_idxs.size(0) *
triton.cdiv(META['N'], META['BLOCK_N']),
)
return grid_num
"""
print("X", X.size(), X.stride(),
"W", W.size(), W.stride(),
"O", O.size(), O.stride(),
"sorted_idxs", sorted_scattered_idxs.size(),
"FAN_OUT", k,
"BLOCK_M", BLOCK_M,
"grouped", (x_grouped, y_grouped))
"""
_scatter2scatter[grid](
# X_ptr, stride_xm, stride_xk,
X, X.stride(0), X.stride(1),
# W_ptr, stride_we, stride_wk, stride_wn,
W, W.stride(0), W.stride(1), W.stride(2),
# Y_ptr, stride_ym, stride_yn,
O, O.stride(0), O.stride(1),
grouped_idx_ptr=sorted_scattered_idxs,
expert_idxs_ptr=sorted_expert_idxs,
block_start_idx_ptr=padded_block_idxs,
FAN_OUT=k,
M=X.size(0),
K=X.size(1),
N=O.size(1), E=W.size(0),
BLOCK_M=BLOCK_M,
ACC_TYPE=tl.float32,
OUT_M=O.size(0),
allow_tf32=True,
x_grouped=x_grouped, y_grouped=y_grouped,
)
return O
def _config_XtY():
return [
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4),
]
def group_bwd_W(DY, X, expert_offsets, E):
DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
DW = DWt.permute(0, 2, 1)
def grid(META):
grid = (
E * triton.cdiv(META['K'], META['BLOCK_K']),
triton.cdiv(META['N'], META['BLOCK_N']),
)
return grid
_groupXtY[grid](
# DY_ptr, stride_dym, stride_dyk,
DY, DY.stride(0), DY.stride(1),
# X_ptr, stride_xm, stride_xn,
X, X.stride(0), X.stride(1),
# DW_ptr, stride_dwe, stride_dwk, stride_dwn,
DW, DW.stride(0), DW.stride(1), DW.stride(2),
# expert_offsets_ptr,
expert_offsets,
# K: tl.constexpr, N: tl.constexpr,
M=DY.size(0), N=DY.size(-1), K=X.size(-1),
# ACC_TYPE: tl.constexpr,
ACC_TYPE=tl.float32,
allow_tf32=True
)
return DW
@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], )
@triton.heuristics({
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
"NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
})
@triton.jit
def _groupXtY(
DY_ptr, stride_dym, stride_dyk,
X_ptr, stride_xm, stride_xn,
DW_ptr, stride_dwe, stride_dwk, stride_dwn,
expert_offsets_ptr,
M: tl.constexpr, K: tl.constexpr, N: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
num0 = tl.num_programs(0)
num1 = tl.num_programs(1)
pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
E_idx = pid0 // K_BLOCK_COUNT
K_block_id = pid0 % K_BLOCK_COUNT
N_block_id = pid1
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
if end_idx > start_idx:
M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
K_mask = K_block < K
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
M_idxs = M_block
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
for i in range(0, iters):
M_mask = (i * BLOCK_M + M_block) < end_idx
if NO_K_MASK:
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
else:
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
if NO_N_MASK:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
else:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
xt_blk_ptrs += BLOCK_M * stride_xm
dy_blk_ptrs += BLOCK_M * stride_dym
DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
def _config_grouping():
return [
triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
]
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
N = sorted_expert_idxs.size(0)
K = A.size(1)
assert A.size(0) * fan_out == N
if out is not None:
Y = out
else:
Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
# print("grp init:", Y.size())
def grid(META):
grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),)
return grid_num
_group[grid](
# A_ptr, stride_an, stride_ai,
A, A.stride(0), A.stride(1), coeff is not None, coeff, fan_out,
# Y_ptr, stride_yn, stride_yk,
Y, Y.stride(0), Y.stride(1),
# grouped_idx_ptr,
sorted_expert_idxs,
# N: tl.constexpr, K: tl.constexpr,
N, K
)
return Y
@triton.autotune(configs=_config_grouping(), key=['K'])
@triton.heuristics({
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
})
@triton.jit
def _group(
src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
tgt_ptr, stride_tn, stride_ti,
grouped_idx_ptr,
N: tl.constexpr, K: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
NO_K_MASK: tl.constexpr
):
pid = tl.program_id(axis=0)
N_block_id = pid
N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_blk < N
N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
K_blk = tl.arange(0, BLOCK_K)
src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
if has_coeff:
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
iters = tl.cdiv(K, BLOCK_K)
for i in range(0, iters):
if NO_K_MASK:
block = tl.load(src_blk_ptrs) # , mask=N_mask[:, None])
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
else:
K_mask = (i * BLOCK_K + K_blk) < K
mask = N_mask[:, None] & K_mask[None, :]
block = tl.load(src_blk_ptrs, mask=mask)
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=mask)
src_blk_ptrs += BLOCK_K * stride_sk
tgt_blk_ptrs += BLOCK_K * stride_ti

View File

@@ -0,0 +1,66 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""
import torch
import triton
import triton.language as tl
from torch.nn import functional as F
@triton.jit
def _single2scatter(
X_ptr, stride_xm, stride_xk,
W_ptr, stride_we, stride_wk, stride_wn,
Y_ptr, stride_ym, stride_yn,
expert_idxs_ptr,
FAN_OUT: tl.constexpr,
K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
N_block_id = pid0
if FAN_OUT == 1:
in_idx = pid1
else:
in_idx = 0
out_idx = pid1
K_block = tl.arange(0, BLOCK_K)
N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
E_idx = tl.load(expert_idxs_ptr + pid1)
X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
x = tl.load(X_blk_ptrs)
w = tl.load(W_blk_ptrs)
acc += tl.sum(x * w, axis=0)[None, :]
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
tl.store(Y_blk_ptrs, acc)
def single2scatter(X, W, expert_idxs):
E, xdim, ydim = W.size()
k = expert_idxs.size(1)
assert X.size(0) == k or X.size(0) == 1
Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
BLOCK_N = 128
BLOCK_K = 128
grid = ydim // BLOCK_N, k
_single2scatter[grid](
X, X.stride(0), X.stride(1),
W, W.stride(0), W.stride(1), W.stride(2),
Y, Y.stride(0), Y.stride(1),
expert_idxs,
FAN_OUT=Y.size(0) // X.size(0),
K=xdim, N=ydim, E=E,
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
ACC_TYPE=tl.float32
)
return Y

View File

@@ -1,6 +1,9 @@
"""multipack patching for v2 of sample packing"""
import importlib
import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
@@ -12,11 +15,12 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"falcon",
"phi",
"gemma",
"gemmoe",
"starcoder2",
]
def patch_for_multipack(model_type):
def patch_for_multipack(model_type, model_name=None):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
@@ -43,3 +47,15 @@ def patch_for_multipack(model_type):
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemmoe":
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_gemmoe to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_gemmoe", ".modeling_gemmoe"
)
modeling_gemmoe = importlib.import_module(module_name)
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -24,6 +24,25 @@ def argilla(
return transform_fn
def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/dpo-mix-7k conversations
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample
return transform_fn
def icr(
cfg,
**kwargs,

View File

@@ -1,10 +1,15 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
from axolotl.utils.tokenization import (
chatml_to_conversation,
merge_consecutive_messages,
)
def register_chatml_template(system_message=None):
@@ -19,6 +24,16 @@ def register_chatml_template(system_message=None):
sep="<|im_end|>",
)
)
register_conv_template(
Conversation(
name="chatml_glaive",
system_template="<|im_start|>system\n{system_message}",
system_message=system_message,
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
@@ -77,6 +92,20 @@ def load_guanaco(tokenizer, cfg):
)
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else "chatml_glaive"
)
return GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation=conversation),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
@@ -158,3 +187,15 @@ class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingSt
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps glaive data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversation = chatml_to_conversation(prompt)
conversation = merge_consecutive_messages(conversation)
return conversation

View File

@@ -360,11 +360,19 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
LOG.warning(f"expected tuple, got {part}")
continue
user, assistant = conversation.roles
tool_role_label = None
if len(conversation.roles) == 3:
(
user_role_label,
assistant_role_label,
tool_role_label,
) = conversation.roles
else:
user_role_label, assistant_role_label = conversation.roles
role, content = part
# Uses "in" because role contains extra characters
if user in role:
if user_role_label in role:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
@@ -384,7 +392,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant in role:
elif assistant_role_label in role:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
@@ -426,6 +434,8 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif tool_role_label and tool_role_label in role:
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue

View File

@@ -267,6 +267,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human = "human"
role_key_model = "gpt"
# Optional, only used for tool usage datasets.
role_key_tool = None
def __init__(
self,
@@ -274,6 +276,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
):
if conversation:
if isinstance(conversation, Conversation):
@@ -286,6 +289,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
self.role_key_human = role_key_human
if role_key_model:
self.role_key_model = role_key_model
if role_key_tool:
self.role_key_tool = role_key_tool
def _build_result(self, source):
if len(source) < 2:
@@ -303,6 +308,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
source.pop(0)
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
if self.role_key_tool:
roles[self.role_key_tool] = conv.roles[2]
try:
# Apply prompt templates

View File

@@ -19,7 +19,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -99,7 +99,7 @@ def train(
safe_serialization = cfg.save_safetensors is True
if cfg.unfrozen_parameters:
freeze_parameters_except(model, cfg.unfrozen_parameters)
freeze_layers_except(model, cfg.unfrozen_parameters)
trainer = setup_trainer(
cfg,

View File

@@ -24,9 +24,9 @@ def check_cuda_device(default_value):
or not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
or torch.device(device).type == "meta"
):
return default_value
return func(*args, **kwargs)
return wrapper

View File

@@ -1,6 +1,7 @@
"""
Module for pydantic models for configuration
"""
# pylint: disable=too-many-lines
import logging
import os
@@ -128,8 +129,10 @@ class RLType(str, Enum):
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
inst = "inst" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
@@ -179,6 +182,7 @@ class LoraConfig(BaseModel):
peft_layers_to_transform: Optional[List[int]] = None
peft: Optional[PeftConfig] = None
peft_use_dora: Optional[bool] = None
peft_use_relora: Optional[bool] = None
lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None
@@ -511,10 +515,12 @@ class AxolotlInputConfig(
neftune_noise_alpha: Optional[float] = None
max_memory: Optional[Union[int, str]] = None
max_memory: Optional[
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
] = None
gpu_memory_limit: Optional[Union[int, str]] = None
chat_template: Optional[Union[Literal["chatml", "inst"], ChatTemplate]] = None
chat_template: Optional[ChatTemplate] = None
default_system_message: Optional[str] = None
# INTERNALS - document for now, generally not set externally
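The widened `max_memory` type accepts the per-device mapping that `transformers`/`accelerate` expect; a hypothetical value (not taken from this diff) would look like:

```python
# Hypothetical max_memory mapping: GPU index or "cpu"/"disk" -> byte count or size string.
max_memory = {0: "20GiB", 1: "20GiB", "cpu": "96GiB"}
```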
@@ -989,3 +995,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_deepspeed(cls, data):
if data.get("deepspeed") and data.get("fsdp"):
raise ValueError("deepspeed and fsdp cannot be used together.")
return data

View File

@@ -3,13 +3,14 @@ module to freeze/unfreeze parameters by name
"""
import logging
import re
from typing import Callable, List, Tuple
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.utils.freeze")
def freeze_parameters_except(model, regex_patterns):
def freeze_layers_except(model, regex_patterns):
"""
Freezes all layers of the given model except for the layers that match given regex patterns.
Periods in the patterns are treated as literal periods, not as wildcard characters.
@@ -17,22 +18,209 @@ def freeze_parameters_except(model, regex_patterns):
Parameters:
- model (nn.Module): The PyTorch model to be modified.
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
Note that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.
Also, to match the entire layer name, the pattern should start with "^" and end with "$", otherwise it will match any part of the layer name.
The range pattern part is optional and it is not compiled as a regex pattern which means you must put "$" before the range pattern if you want to match the entire layer name.
E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]
Returns:
None; the model is modified in place.
"""
# Escape periods and compile the regex patterns
compiled_patterns = [
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
]
if isinstance(regex_patterns, str):
regex_patterns = [regex_patterns]
# First, freeze all parameters in the model
for param in model.parameters():
param.requires_grad = False
patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]
# Unfreeze layers that match the regex patterns
for name, param in model.named_parameters():
if any(pattern.match(name) for pattern in compiled_patterns):
if is_main_process():
LOG.debug(f"unfreezing {name}")
param.requires_grad = False
unfrozen_ranges = []
for pattern in patterns:
if not pattern.match(name):
continue
param.requires_grad = True
if pattern.range is not None:
unfrozen_ranges.append(pattern.range)
merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))
if param.requires_grad and is_main_process():
unfrozen_ranges = (
f" with ranges {merged_unfrozen_ranges}"
if merged_unfrozen_ranges
else ""
)
LOG.debug(f"Unfrozen {name}{unfrozen_ranges}")
if not merged_unfrozen_ranges:
continue
# The range list we need is actually the inverted of the merged ranges
ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))
param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))
if is_main_process() and all(
not param.requires_grad for param in model.parameters()
):
LOG.warning("All parameters are frozen. Model will not be trained.")
def _invert_ranges(
given_ranges: List[Tuple[int, int]], layer_size: int
) -> List[Tuple[int, int]]:
"""
Inverts a list of ranges to obtain the ranges not covered by the given ranges.
Parameters:
- given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
Returns:
- List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
"""
if not given_ranges:
return [(0, layer_size)]
inverted_ranges = []
current_start = 0
for start, end in sorted(given_ranges):
if start > current_start:
inverted_ranges.append((current_start, start))
current_start = max(current_start, end)
# Handle the case where the last given range does not reach the end of the total_size
if current_start < layer_size:
inverted_ranges.append((current_start, layer_size))
return inverted_ranges
def _merge_ranges(
given_ranges: List[Tuple[int, int | None]], layer_size: int
) -> List[Tuple[int, int]]:
"""
Merges overlapping ranges and sorts the given ranges.
This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
as tuples, where the first element is the start index (inclusive) and the second element is the end
index (exclusive). The end index can be None, indicating that the range extends to the end of the
sequence.
Parameters:
- given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
Returns:
- List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
"""
# End of each range can be determined now since we have the total size
processed_ranges = [
(start, end if end is not None else layer_size) for start, end in given_ranges
]
# No need to merge if there's only one or no ranges
if len(processed_ranges) <= 1:
return processed_ranges
sorted_ranges = sorted(processed_ranges)
merged_ranges = [sorted_ranges[0]]
for start, end in sorted_ranges[1:]:
prev_start, prev_end = merged_ranges[-1]
if start <= prev_end:
merged_ranges[-1] = (prev_start, max(prev_end, end))
else:
merged_ranges.append((start, end))
return merged_ranges
def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
"""
Create a hook to freeze parameters in specified ranges by setting their gradients to zero.
This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
two integers representing the start and end indices of the range.
Parameters:
- ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.
Returns:
- Callable: A hook function to be used with `register_hook` on parameters.
Example usage:
```
ranges_to_freeze = [(0, 10), (20, 30)]
hook = _create_freeze_parameters_hook(ranges_to_freeze)
model.register_hook(hook)
```
"""
def freeze_parameters_hook(gradients):
for start, end in ranges_to_freeze:
gradients[start:end].zero_()
return freeze_parameters_hook
class LayerNamePattern:
"""
Represents a regex pattern for layer names, potentially including a parameter index range.
"""
def __init__(self, pattern: str):
"""
Initializes a new instance of the LayerNamePattern class.
Parameters:
- pattern (str): The regex pattern for layer names, potentially including a parameter index range.
"""
self.raw_pattern = pattern
name_pattern, self.range = self._parse_pattern(pattern)
self.name_regex = re.compile(name_pattern.replace(".", "\\."))
def match(self, name: str) -> bool:
"""
Checks if the given layer name matches the regex pattern.
Parameters:
- name (str): The layer name to check.
Returns:
- bool: True if the layer name matches the pattern, False otherwise.
"""
return self.name_regex.match(name) is not None
def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
"""
Extracts the range pattern from the given pattern.
Parameters:
- pattern (str): The pattern to extract the range from.
Returns:
- Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
"""
match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
if not match:
return pattern, None
base_pattern, start_part, end_part = match.groups()
if end_part is None and start_part.isdecimal():
index = int(start_part)
return base_pattern, (index, index + 1)
# [:end] or [start:] or [start:end]
start = int(start_part) if start_part else 0
end = int(end_part) if end_part else None
if end is not None and start >= end:
raise ValueError(
f"Invalid range in layer name pattern: {pattern}."
"End of range must be greater than start."
)
return base_pattern, (start, end)
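To make the pattern semantics concrete, a toy sketch of calling the renamed helper (module names are illustrative, and a single-process run is assumed so `is_main_process()` is trivially true):

```python
# Sketch: unfreeze lm_head fully and only the first 32000 embedding rows (toy module, not a real model).
import torch.nn as nn

from axolotl.utils.freeze import freeze_layers_except

model = nn.Module()
model.embed_tokens = nn.Embedding(32128, 16)
model.lm_head = nn.Linear(16, 32128, bias=False)

freeze_layers_except(model, ["^embed_tokens.weight$[:32000]", "^lm_head.weight$"])

assert model.embed_tokens.weight.requires_grad  # gradients for rows >= 32000 are zeroed by the hook
assert model.lm_head.weight.requires_grad
```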

View File

@@ -1,13 +1,20 @@
"""Module for models and model loading"""
# pylint: disable=too-many-lines
import logging
import math
import os
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import types
from typing import Any, Dict, List, Optional, Tuple, Type, Union # noqa: F401
import addict
import bitsandbytes as bnb
import safetensors
import torch
import transformers
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from peft import (
LoftQConfig,
PeftConfig,
@@ -16,6 +23,7 @@ from peft import (
prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from torch import Tensor, nn
from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
@@ -27,7 +35,9 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -262,6 +272,117 @@ def load_tokenizer(cfg):
return tokenizer
def replace_linear(
model: nn.Module,
linear_replacement: Type[nn.Module],
quant_config: Union[dict, None] = None,
skip_modules=None,
**kwargs,
):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
List of module names not to convert.
"""
if skip_modules is None:
skip_modules = ["lm_head"]
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(
module, linear_replacement, quant_config, skip_modules, **kwargs
)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
if issubclass(linear_replacement, Linear4bit):
model._modules[ # pylint: disable=protected-access
name
] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
**kwargs,
)
else:
raise ValueError(
f"Unsupported linear replacement: {type(linear_replacement)}"
)
return model
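As an illustration only, a toy module run through `replace_linear` with the same keyword arguments the QLoRA+FSDP path below passes; the `ToyBlock` class is hypothetical and not part of this change:
```
import torch
from torch import nn
from bitsandbytes.nn import Linear4bit

class ToyBlock(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64, bias=False)
        self.lm_head = nn.Linear(64, 128, bias=False)

toy = replace_linear(
    ToyBlock(),
    Linear4bit,
    compute_dtype=torch.bfloat16,
    quant_type="nf4",
    quant_storage=torch.float32,
)
assert isinstance(toy.proj, Linear4bit)   # converted in place
assert type(toy.lm_head) is nn.Linear     # "lm_head" is skipped by default
```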
def load_and_quantize(
module: nn.Module,
name: str,
value: Tensor,
device: torch.device = None,
dtype: torch.dtype = None,
skip_names: Optional[List[str]] = None,
is_meta_rank: bool = False,
low_memory: bool = True,
verbose: bool = False,
quant_method: str = "bnb",
):
"""
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
"""
if skip_names is None:
skip_names = []
def place_on_device(value):
if is_meta_rank:
device = "meta"
elif low_memory:
device = "cpu"
else:
device = "cuda"
return value.to(device=device, dtype=dtype)
if any(skip_name in name for skip_name in skip_names):
if verbose:
print(f"Skipping {name} because it is in skip_names")
return
module_key, _, value_key = name.rpartition(".")
try:
submodule = module.get_submodule(module_key)
except AttributeError as exc:
print(f"Module {module_key} not found:\n{exc}")
return
try:
if quant_method == "bnb":
param = submodule.get_parameter(value_key)
if isinstance(param, Params4bit):
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
# shape as the quantized Params4bit with an initialized quant_state. However,
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device)
if is_meta_rank:
value = type(param)(value.data.to("meta"), **value.__dict__)
elif low_memory:
value = type(param)(value.data.to("cpu"), **value.__dict__)
else:
value = type(param)(place_on_device(value).data)
except AttributeError:
# it's a buffer
value = place_on_device(value)
setattr(submodule, value_key, value)
def load_model(
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
@@ -308,7 +429,7 @@ def load_model(
and cfg.flash_attention
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type)
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block
@@ -394,7 +515,7 @@ def load_model(
if max_memory is not None:
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate import infer_auto_device_map
with init_empty_weights():
model_canvas = AutoModelForCausalLM.from_config(model_config)
@@ -496,8 +617,78 @@ def load_model(
model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = "eager" # pylint: disable=protected-access
qlora_fsdp = (
cfg.fsdp
and cfg.adapter == "qlora"
and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
)
try:
if (
if qlora_fsdp:
if cfg.bf16 or cfg.bfloat16:
torch_dtype, compute_dtype = torch.float32, torch.bfloat16
elif cfg.fp16 or cfg.float16:
torch_dtype, compute_dtype = torch.float32, torch.float16
else:
torch_dtype, compute_dtype = torch.float32, torch.float16
with init_empty_weights():
LOG.info("Loading model with empty weights.")
model = AutoModelForCausalLM.from_config(model_config)
model.model = replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=torch_dtype,
)
model.is_loaded_in_4bit = True
# Grab the safetensors files that hold the weights
try:
idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(base_model, idx)
except OSError:
try:
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
files = []
files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
except OSError as exc:
# This means the model probably doesn't have a safetensors file
raise exc
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
def load_and_quantize_parallel(name_param, model, **kwargs):
name, param = name_param
load_and_quantize(model, name, param, **kwargs)
param_count = sum((p.numel() for n, p in model.named_parameters()))
for filename in files:
weights = safetensors.torch.load_file(filename)
quant_method = "bnb"
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
left = int(os.cpu_count() / torch.cuda.device_count())
right = int(
8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
)
n_workers = min(left, right)
parallel(
load_and_quantize_parallel,
weights.items(),
n_workers=n_workers,
threadpool=True,
model=model,
dtype=torch_dtype,
device=cfg.local_rank,
skip_names=[],
is_meta_rank=(cfg.local_rank != 0),
verbose=False,
quant_method=quant_method,
)
elif (
model_config.model_type == "llama"
and not cfg.trust_remote_code
and not cfg.gptq
@@ -524,32 +715,27 @@ def load_model(
if cfg.flash_attn_fuse_qkv:
LOG.info("patching with fused QKV")
replace_llama_qkv_with_fused(model)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
# This is a WIP, still an issue with the backward pass
# RuntimeError: grad can be implicitly created only for scalar outputs
# TODO: try config.sequence_parallel = False
# # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
# # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
# # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
# from flash_attn.utils.pretrained import state_dict_from_pretrained
# from flash_attn.models.gpt import GPTLMHeadModel
# from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
# from transformers import GPTNeoXConfig
# config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
# config.use_flash_attn = True
# config.fused_bias_fc = True
# config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
# config.activation_function = "gelu_fast"
# config.fused_dropout_add_ln = True
# # config.residual_in_fp32 = True
#
# model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
# base_model,
# config,
# dtype=torch_dtype,
# device=cfg.device,
# )
# model.train() # sets to train instead of eval mode
elif (
model_config.model_type == "mixtral"
and not cfg.adapter
and cfg.fuse_moe
):
from axolotl.monkeypatch.utils import set_module_name
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
for name, module in model.named_modules():
if isinstance(module, MixtralSparseMoeBlock):
smoe = SparseMoeBlock(
experts=module.experts,
gate=module.gate,
hidden_dim=module.hidden_dim,
ffn_dim=module.ffn_dim,
num_experts=module.num_experts,
top_k=module.top_k,
)
set_module_name(model, name, smoe)
elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
@@ -613,7 +799,7 @@ def load_model(
LOG.exception(err)
raise err
if isinstance(model, (PeftModel, PeftModelForCausalLM)):
if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
model = model.merge_and_unload()
embeddings_len = (
@@ -692,6 +878,9 @@ def load_model(
if cfg.adapter == "lora" and loftq_bits:
skip_prepare_model_for_kbit_training = True
if qlora_fsdp:
skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]:
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
@@ -706,7 +895,7 @@ def load_model(
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if needs_fa2_dtype or cfg.flash_attention:
if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules():
if "norm" in name:
@@ -724,7 +913,12 @@ def load_model(
else:
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
if (
cfg.ddp
and not load_in_8bit
and not (cfg.rl and cfg.load_in_4bit)
and not qlora_fsdp
):
# TODO revalidate this conditional
model.to(f"cuda:{cfg.local_rank}")
@@ -813,6 +1007,30 @@ def find_all_linear_names(model):
return list(lora_module_names)
def setup_quantized_meta_for_peft(model: nn.Module):
"""Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument
return self
for param in model.parameters():
if isinstance(param, Params4bit):
param.quant_state._orig_to = ( # pylint: disable=protected-access
param.quant_state.to
)
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
def setup_quantized_peft_meta_for_training(model: nn.Module):
"""Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
for param in model.parameters():
if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
param.quant_state.to = (
param.quant_state._orig_to # pylint: disable=protected-access
)
param.quant_state._orig_to = None # pylint: disable=protected-access
def load_lora(model, cfg, inference=False, config_only=False):
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
@@ -832,6 +1050,8 @@ def load_lora(model, cfg, inference=False, config_only=False):
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
if cfg.peft_use_rslora:
lora_config_kwargs["use_rslora"] = cfg.use_rslora
lora_config = LoraConfig(
r=cfg.lora_r,
@@ -849,6 +1069,11 @@ def load_lora(model, cfg, inference=False, config_only=False):
if config_only:
return None, lora_config
rank = int(os.environ.get("LOCAL_RANK", 0))
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
setup_quantized_meta_for_peft(model)
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - LoRA")
model_kwargs: Any = {}
@@ -864,6 +1089,9 @@ def load_lora(model, cfg, inference=False, config_only=False):
else:
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
if rank == 0:
model.print_trainable_parameters()
elif cfg.fsdp and cfg.adapter == "qlora":
setup_quantized_peft_meta_for_training(model)
return model, lora_config

View File

@@ -2,6 +2,8 @@
import logging
import re
from typing import Dict, List
from termcolor import colored
@@ -36,3 +38,65 @@ def check_example_labels(example, tokenizer, text_only=False):
LOG.info("\n\n\n")
return " ".join(colored_tokens)
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
GLAIVE_TO_SHAREGPT_ROLE = {
"SYSTEM": "system",
"USER": "human",
"ASSISTANT": "gpt",
"FUNCTION RESPONSE": "tool",
}
GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")
def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
"""
Converts a ChatML formatted row to a list of messages in ShareGPT format.
Initially based on https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.
"""
system_prompt = row.get("system")
if system_prompt:
system_prompt = system_prompt.removeprefix("SYSTEM: ")
chat_str = row["chat"]
chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s]
chat_msg_dicts = [
{"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value}
for role, value in zip(chat_msgs[::2], chat_msgs[1::2])
]
if system_prompt:
chat_msg_dicts = [
{"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
] + chat_msg_dicts
return chat_msg_dicts
def merge_consecutive_messages(messages):
"""
Merge consecutive messages from the same sender into a single message.
This can be useful with datasets that contain multiple consecutive tool calls.
"""
merged_messages = []
current_from = None
current_message = ""
for msg in messages:
if current_from == msg["from"]:
current_message += msg["value"]
else:
if current_from is not None:
merged_messages.append({"from": current_from, "value": current_message})
current_from = msg["from"]
current_message = msg["value"]
if current_from is not None:
merged_messages.append({"from": current_from, "value": current_message})
return merged_messages
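A minimal sketch of what these helpers return for a Glaive-style row (the row mirrors the fixture used by the tests later in this diff):
```
row = {
    "system": "SYSTEM: This is a system prompt",
    "chat": (
        "USER: Can you book a flight for me from New York to London? "
        "ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>"
    ),
}
msgs = chatml_to_conversation(row)
# [{"from": "system", "value": "This is a system prompt"},
#  {"from": "human", "value": "Can you book a flight for me from New York to London?"},
#  {"from": "gpt", "value": "I'm sorry, but I don't have the capability to book flights. <|endoftext|>"}]

# merge_consecutive_messages is a no-op here, but collapses repeated senders when a dataset
# contains several consecutive tool calls.
msgs = merge_consecutive_messages(msgs)
assert len(msgs) == 3
```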

View File

@@ -0,0 +1,60 @@
import torch
import pytest
from torch import nn
from torch.nn import functional as F
from axolotl.monkeypatch.moe.mlp import FusedExperts
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralConfig
def test_fused_mixtral_moe():
# NOTE: Requires torch 2.2.0
# Set random seeds for reproducibility
torch.set_default_dtype(torch.float16)
torch.set_default_device("cuda")
torch.manual_seed(0)
# Define the configuration for the MixtralSparseMoeBlock
config = MixtralConfig(
hidden_size=128,
intermediate_size=512,
num_local_experts=8,
num_experts_per_tok=2,
)
# Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
mixtral_moe = MixtralSparseMoeBlock(config)
sparse_moe = SparseMoeBlock(
experts=mixtral_moe.experts,
gate=mixtral_moe.gate,
hidden_dim=config.hidden_size,
ffn_dim=config.intermediate_size,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok
)
assert torch.cat([
mixtral_moe.experts[0].w1.weight.data,
mixtral_moe.experts[0].w3.weight.data], dim=0
).equal(sparse_moe.experts.experts.weight[0])
# Generate random input data
batch_size = 16
sequence_length = 32
input_data = torch.randn(batch_size, sequence_length, config.hidden_size)
# Run the forward pass with gradients for both models
with torch.no_grad():
mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
sparse_output, sparse_router_logits = sparse_moe(input_data)
# Compute the difference between the outputs
output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
# Define the tolerance for the difference
tolerance = 0.05
# Check that the expert outputs agree within the tolerance and that the router logits match
# exactly (both blocks share the same gate module, so its logits should be identical)
assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
assert router_diff == 0, f"Router logits differ by {router_diff}, but an exact match was expected"

View File

@@ -1,6 +1,7 @@
"""
Test module for sharegpt integration w chatml
"""
import pytest
from datasets import Dataset
from tokenizers import AddedToken
@@ -8,6 +9,7 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.sharegpt import (
GlaiveShareGPTPromptTokenizingStrategy,
SimpleShareGPTPromptTokenizingStrategy,
register_chatml_template,
)
@@ -48,6 +50,18 @@ def fixture_sharegpt_dataset():
)
@pytest.fixture(name="glaive_dataset")
def fixture_sharegpt_glaive_dataset():
return Dataset.from_list(
[
{
"system": "SYSTEM: This is a system prompt",
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
}
]
)
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
@@ -156,3 +170,29 @@ class TestSharegpt:
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
]
# fmt: on
def test_chatml_glaive(self, glaive_dataset, tokenizer):
strategy = GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="chatml",
role_key_model=None,
role_key_human=None,
),
tokenizer,
True, # train_on_inputs
2048, # sequence_len
)
dataset_wrapper = TokenizedPromptDataset(
strategy, glaive_dataset, process_count=1
)
labels = dataset_wrapper[0]["labels"]
# fmt: off
assert labels == [
1, # bos
32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13, # system
32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13, # human
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
]
# fmt: on

tests/test_freeze.py Normal file
View File

@@ -0,0 +1,285 @@
"""
This module contains unit tests for the `freeze_layers_except` function.
The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.
The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.
"""
import unittest
import torch
from torch import nn
from axolotl.utils.freeze import freeze_layers_except
ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
class TestFreezeLayersExcept(unittest.TestCase):
"""
A test case class for the `freeze_layers_except` function.
"""
def setUp(self):
self.model = _TestModel()
def test_freeze_layers_with_dots_in_name(self):
freeze_layers_except(self.model, ["features.layer"])
self.assertTrue(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be trainable.",
)
self.assertFalse(
self.model.classifier.weight.requires_grad,
"model.classifier should be frozen.",
)
def test_freeze_layers_without_dots_in_name(self):
freeze_layers_except(self.model, ["classifier"])
self.assertFalse(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be trainable.",
)
self.assertTrue(
self.model.classifier.weight.requires_grad,
"model.classifier should be frozen.",
)
def test_freeze_layers_regex_patterns(self):
# The second pattern cannot match: it only allows the characters 'a' to 'c' after the word 'class', so it never matches 'classifier', whose next character is 'i'.
freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"])
self.assertTrue(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be trainable.",
)
self.assertFalse(
self.model.classifier.weight.requires_grad,
"model.classifier should be frozen.",
)
def test_all_layers_frozen(self):
freeze_layers_except(self.model, [])
self.assertFalse(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be frozen.",
)
self.assertFalse(
self.model.classifier.weight.requires_grad,
"model.classifier should be frozen.",
)
def test_all_layers_unfrozen(self):
freeze_layers_except(self.model, ["features.layer", "classifier"])
self.assertTrue(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be trainable.",
)
self.assertTrue(
self.model.classifier.weight.requires_grad,
"model.classifier should be trainable.",
)
def test_freeze_layers_with_range_pattern_start_end(self):
freeze_layers_except(self.model, ["features.layer[1:5]"])
self.assertTrue(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be trainable.",
)
self.assertFalse(
self.model.classifier.weight.requires_grad,
"model.classifier should be frozen.",
)
self._assert_gradient_output(
[
ZERO,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ZERO,
ZERO,
ZERO,
ZERO,
ZERO,
]
)
def test_freeze_layers_with_range_pattern_single_index(self):
freeze_layers_except(self.model, ["features.layer[5]"])
self.assertTrue(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be trainable.",
)
self.assertFalse(
self.model.classifier.weight.requires_grad,
"model.classifier should be frozen.",
)
self._assert_gradient_output(
[ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO]
)
def test_freeze_layers_with_range_pattern_start_omitted(self):
freeze_layers_except(self.model, ["features.layer[:5]"])
self.assertTrue(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be trainable.",
)
self.assertFalse(
self.model.classifier.weight.requires_grad,
"model.classifier should be frozen.",
)
self._assert_gradient_output(
[
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ZERO,
ZERO,
ZERO,
ZERO,
ZERO,
]
)
def test_freeze_layers_with_range_pattern_end_omitted(self):
freeze_layers_except(self.model, ["features.layer[4:]"])
self.assertTrue(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be trainable.",
)
self.assertFalse(
self.model.classifier.weight.requires_grad,
"model.classifier should be frozen.",
)
self._assert_gradient_output(
[
ZERO,
ZERO,
ZERO,
ZERO,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
]
)
def test_freeze_layers_with_range_pattern_merge_included(self):
freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"])
self.assertTrue(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be trainable.",
)
self.assertFalse(
self.model.classifier.weight.requires_grad,
"model.classifier should be frozen.",
)
self._assert_gradient_output(
[
ZERO,
ZERO,
ZERO,
ZERO,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
]
)
def test_freeze_layers_with_range_pattern_merge_intersect(self):
freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"])
self.assertTrue(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be trainable.",
)
self.assertFalse(
self.model.classifier.weight.requires_grad,
"model.classifier should be frozen.",
)
self._assert_gradient_output(
[
ZERO,
ZERO,
ZERO,
ZERO,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ONE_TO_TEN,
ZERO,
ZERO,
]
)
def test_freeze_layers_with_range_pattern_merge_separate(self):
freeze_layers_except(
self.model,
["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"],
)
self.assertTrue(
self.model.features.layer.weight.requires_grad,
"model.features.layer should be trainable.",
)
self.assertFalse(
self.model.classifier.weight.requires_grad,
"model.classifier should be frozen.",
)
self._assert_gradient_output(
[
ZERO,
ONE_TO_TEN,
ZERO,
ONE_TO_TEN,
ZERO,
ONE_TO_TEN,
ZERO,
ZERO,
ZERO,
ZERO,
]
)
def _assert_gradient_output(self, expected):
input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)
self.model.features.layer.weight.grad = None # Reset gradients
output = self.model.features.layer(input_tensor)
loss = output.sum()
loss.backward()
expected_grads = torch.tensor(expected)
torch.testing.assert_close(
self.model.features.layer.weight.grad, expected_grads
)
class _SubLayerModule(nn.Module):
def __init__(self):
super().__init__()
self.layer = nn.Linear(10, 10)
class _TestModel(nn.Module):
def __init__(self):
super().__init__()
self.features = _SubLayerModule()
self.classifier = nn.Linear(10, 2)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,4 +1,5 @@
"""Module for testing prompt tokenizers."""
import json
import logging
import unittest
@@ -18,6 +19,7 @@ from axolotl.prompt_strategies.llama2_chat import (
Llama2ChatPrompter,
LLama2ChatTokenizingStrategy,
)
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
@@ -266,6 +268,23 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
idx = res["input_ids"].index(20255) # assistant token
assert res["labels"][idx] == -100
def test_glaive_tool_label_ignore(self):
conversation = {
"system": "SYSTEM: This is a system prompt",
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
}
prompter = ShareGPTPrompterV2()
strat = GlaiveShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
res = strat.tokenize_prompt(conversation)
idx = res["input_ids"].index(13566) # assistant token
assert res["labels"][idx] == -100
def test_no_sys_prompt(self):
"""
tests the interface between the user and assistant parts