Implement fused modules (#747)
* MLP: Memory saving * Remove RMSNorm restrictions * Map packed weights to original * FusedAttention module * Simplify code * Move fused modules * Fix critical typo * Split inplace * Add FFT config * Add validation of fused arguments * Add fused arguments to config * Update docs * Fix validation logic * Add fused modules to flash attn * Only fuse during training * Remove timing * Formatting * Formatting * Formatting * chore: lint * chore: lint * add e2e tests for fused llama * no lora for tests --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -684,6 +684,8 @@ xformers_attention:
|
|||||||
flash_attention:
|
flash_attention:
|
||||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||||
|
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
||||||
|
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||||
# Whether to use scaled-dot-product attention
|
# Whether to use scaled-dot-product attention
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
sdp_attention:
|
sdp_attention:
|
||||||
|
|||||||
@@ -9,12 +9,16 @@ gradient_accumulation_steps: 2
|
|||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
|
accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
|
||||||
|
|
||||||
```
|
```
|
||||||
or
|
or
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
|
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
To launch a full finetuning with 16-bit precision:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
|
||||||
```
|
```
|
||||||
|
|||||||
73
examples/llama-2/fft_optimized.yml
Normal file
73
examples/llama-2/fft_optimized.yml
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
base_model: NousResearch/Llama-2-7b-hf
|
||||||
|
base_model_config: NousResearch/Llama-2-7b-hf
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_llama_derived_model: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter:
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r:
|
||||||
|
lora_alpha:
|
||||||
|
lora_dropout:
|
||||||
|
lora_target_linear:
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_run_id:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
flash_attn_cross_entropy: false
|
||||||
|
flash_attn_rms_norm: true
|
||||||
|
flash_attn_fuse_qkv: false
|
||||||
|
flash_attn_fuse_mlp: true
|
||||||
|
|
||||||
|
warmup_steps: 100
|
||||||
|
eval_steps: 0.05
|
||||||
|
eval_table_size:
|
||||||
|
save_steps:
|
||||||
|
debug:
|
||||||
|
deepspeed: #deepspeed/zero2.json # multi-gpu only
|
||||||
|
weight_decay: 0.1
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
0
src/axolotl/monkeypatch/fused_modules.py
Normal file
0
src/axolotl/monkeypatch/fused_modules.py
Normal file
@@ -13,12 +13,18 @@ import transformers
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||||
)
|
)
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaMLP,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
|
from xformers.ops import SwiGLU
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
@@ -38,6 +44,28 @@ except ImportError:
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
def replace_llama_mlp_with_swiglu(model):
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, LlamaMLP):
|
||||||
|
mlp = FusedMLP(
|
||||||
|
module.config, module.gate_proj, module.up_proj, module.down_proj
|
||||||
|
)
|
||||||
|
set_module_name(model, name, mlp)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_llama_qkv_with_fused(model):
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, LlamaAttention):
|
||||||
|
qkv = FusedAttention(
|
||||||
|
module.config,
|
||||||
|
module.q_proj,
|
||||||
|
module.k_proj,
|
||||||
|
module.v_proj,
|
||||||
|
module.o_proj,
|
||||||
|
)
|
||||||
|
set_module_name(model, name, qkv)
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_attn_with_flash_attn(
|
def replace_llama_attn_with_flash_attn(
|
||||||
packed: Optional[bool] = False,
|
packed: Optional[bool] = False,
|
||||||
cross_entropy: Optional[bool] = False,
|
cross_entropy: Optional[bool] = False,
|
||||||
@@ -86,6 +114,91 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedAttention(LlamaAttention):
|
||||||
|
"""
|
||||||
|
Fused QKV Attention layer for incrementally improved training efficiency
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
q: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
k: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
v: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
o: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.init_device = next(iter(q.state_dict().values())).device
|
||||||
|
|
||||||
|
# define equivalent fused qkv projection
|
||||||
|
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
||||||
|
self.qkv_proj = torch.nn.Linear(
|
||||||
|
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
||||||
|
)
|
||||||
|
self.o_proj = o
|
||||||
|
|
||||||
|
# overwrite initialized weights with pretrained weights
|
||||||
|
self.qkv_proj.weight.data = torch.cat(
|
||||||
|
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
def _post_training(self, model, name):
|
||||||
|
q_proj, k_proj, v_proj = torch.split(
|
||||||
|
self.qkv_proj.weight.data, self.out_features, dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
new_attn = LlamaAttention(self.config)
|
||||||
|
new_attn.q_proj.weight.data = q_proj
|
||||||
|
new_attn.k_proj.weight.data = k_proj
|
||||||
|
new_attn.v_proj.weight.data = v_proj
|
||||||
|
|
||||||
|
set_module_name(model, name, new_attn)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMLP(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Fused MLP layer for incrementally improved training efficiency
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
gate_proj: torch.nn.Linear,
|
||||||
|
up_proj: torch.nn.Linear,
|
||||||
|
down_proj: torch.nn.Linear,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.swiglu = SwiGLU(
|
||||||
|
in_features=config.hidden_size,
|
||||||
|
hidden_features=config.intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
_pack_weights=True,
|
||||||
|
)
|
||||||
|
# overwrite initialized weights with pretrained weights
|
||||||
|
self.swiglu.w12.weight.data = torch.cat(
|
||||||
|
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||||
|
)
|
||||||
|
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||||
|
|
||||||
|
def _post_training(self, model, name):
|
||||||
|
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||||
|
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign the split weights back to the original layers
|
||||||
|
new_mlp = LlamaMLP(self.config)
|
||||||
|
new_mlp.gate_proj.weight.data = w1
|
||||||
|
new_mlp.up_proj.weight.data = w2
|
||||||
|
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||||
|
|
||||||
|
set_module_name(model, name, new_mlp)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||||
|
return self.swiglu(x)
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
# requires the attention mask to be the same as the key_padding_mask
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
def _prepare_decoder_attention_mask(
|
def _prepare_decoder_attention_mask(
|
||||||
@@ -147,9 +260,14 @@ def flashattn_forward(
|
|||||||
value_states = torch.cat(value_states, dim=-1)
|
value_states = torch.cat(value_states, dim=-1)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query_states = self.q_proj(hidden_states)
|
if isinstance(self, FusedAttention):
|
||||||
key_states = self.k_proj(hidden_states)
|
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
||||||
value_states = self.v_proj(hidden_states)
|
self.out_features, dim=-1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
query_states = query_states.view(
|
query_states = query_states.view(
|
||||||
bsz, q_len, self.num_heads, self.head_dim
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
|||||||
@@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
|||||||
max_seq_lens.append(max_seq_len)
|
max_seq_lens.append(max_seq_len)
|
||||||
|
|
||||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
|
def set_module_name(model, name, value):
|
||||||
|
if "." in name:
|
||||||
|
parent_name = name.rsplit(".", 1)[0]
|
||||||
|
child_name = name[len(parent_name) + 1 :]
|
||||||
|
parent = model.get_submodule(parent_name)
|
||||||
|
else:
|
||||||
|
parent_name = ""
|
||||||
|
parent = model
|
||||||
|
child_name = name
|
||||||
|
|
||||||
|
setattr(parent, child_name, value)
|
||||||
|
|||||||
@@ -40,10 +40,7 @@ class TrainDatasetMeta:
|
|||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
*,
|
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs,
|
|
||||||
dataset_meta: TrainDatasetMeta,
|
|
||||||
):
|
):
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||||
@@ -120,6 +117,11 @@ def train(
|
|||||||
|
|
||||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||||
|
|
||||||
|
# post training
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if hasattr(module, "_post_training"):
|
||||||
|
module._post_training(model, name) # pylint: disable=protected-access
|
||||||
|
|
||||||
if trainer.is_fsdp_enabled:
|
if trainer.is_fsdp_enabled:
|
||||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
||||||
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
||||||
|
|||||||
@@ -189,9 +189,15 @@ def validate_config(cfg):
|
|||||||
if not cfg.load_in_4bit:
|
if not cfg.load_in_4bit:
|
||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
|
|
||||||
|
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||||
|
raise ValueError("Fused modules are not supported with QLoRA")
|
||||||
|
|
||||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||||
|
|
||||||
|
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||||
|
raise ValueError("Fused modules are not supported with LoRA")
|
||||||
|
|
||||||
if cfg.relora_steps:
|
if cfg.relora_steps:
|
||||||
if cfg.adapter not in ("lora", "qlora"):
|
if cfg.adapter not in ("lora", "qlora"):
|
||||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||||
@@ -205,6 +211,9 @@ def validate_config(cfg):
|
|||||||
if cfg.lr_scheduler == "one_cycle":
|
if cfg.lr_scheduler == "one_cycle":
|
||||||
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
||||||
|
|
||||||
|
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||||
|
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||||
|
|
||||||
if cfg.trust_remote_code:
|
if cfg.trust_remote_code:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||||
|
|||||||
@@ -272,6 +272,20 @@ def load_model(
|
|||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.flash_attention and not inference:
|
||||||
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
replace_llama_mlp_with_swiglu,
|
||||||
|
replace_llama_qkv_with_fused,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.flash_attn_fuse_mlp:
|
||||||
|
LOG.info("patching with SwiGLU")
|
||||||
|
replace_llama_mlp_with_swiglu(model)
|
||||||
|
|
||||||
|
if cfg.flash_attn_fuse_qkv:
|
||||||
|
LOG.info("patching with fused QKV")
|
||||||
|
replace_llama_qkv_with_fused(model)
|
||||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||||
# This is a WIP, still an issue with the backward pass
|
# This is a WIP, still an issue with the backward pass
|
||||||
# RuntimeError: grad can be implicitly created only for scalar outputs
|
# RuntimeError: grad can be implicitly created only for scalar outputs
|
||||||
|
|||||||
117
tests/e2e/test_fused_llama.py
Normal file
117
tests/e2e/test_fused_llama.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for lora llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedLlama(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models using Fused layers
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_lora_packing(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
output_dir = tempfile.mkdtemp()
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"base_model_config": "JackFram/llama-68m",
|
||||||
|
"flash_attention": True,
|
||||||
|
"flash_attn_fuse_qkv": True,
|
||||||
|
"flash_attn_fuse_mlp": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": output_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||||
|
|
||||||
|
def test_fft_packing(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
output_dir = tempfile.mkdtemp()
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"base_model_config": "JackFram/llama-68m",
|
||||||
|
"flash_attention": True,
|
||||||
|
"flash_attn_fuse_qkv": True,
|
||||||
|
"flash_attn_fuse_mlp": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": output_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if is_torch_bf16_gpu_available():
|
||||||
|
cfg.bf16 = True
|
||||||
|
else:
|
||||||
|
cfg.fp16 = True
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||||
Reference in New Issue
Block a user