Compare commits

...

3 Commits

Author     SHA1        Message                 Date
Wing Lian  3ade0b81db  add example yaml        2024-09-01 21:20:48 -04:00
Wing Lian  756a34f0fe  wip for tp              2024-08-23 10:57:57 -04:00
Wing Lian  198f7cd893  2d parallel llama fsdp  2024-08-23 00:02:14 -04:00
2 changed files with 145 additions and 0 deletions

View File

@@ -0,0 +1,62 @@
base_model: nvidia/Llama-3.1-Minitron-4B-Width-Base
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train
    train_on_eos: turn
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project: device_mesh-test
wandb_entity: axolotl-ai
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: true
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
eager_attention:
warmup_steps: 100
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - auto_wrap
fsdp_config:
  fsdp_use_orig_params: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:
  pad_token: <|end_of_text|>

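The fsdp / fsdp_config block above is consumed by the trainer builder in the second file below, which drops the fsdp_ prefix from each key before the dict reaches transformers' TrainingArguments(fsdp_config=...). A minimal sketch of that mapping, not part of this diff, assuming the example is saved locally under the hypothetical name config.yaml and PyYAML is available; the builder itself uses str.lstrip, which for these particular keys behaves like a prefix strip:

# Minimal sketch (hypothetical file name, not part of this diff): how the YAML
# fsdp_config keys above map onto the dict handed to TrainingArguments(fsdp_config=...).
import yaml

with open("config.yaml") as f:  # the example config above, saved locally
    cfg = yaml.safe_load(f)

hf_fsdp_config = {
    k.removeprefix("fsdp_"): v for k, v in (cfg.get("fsdp_config") or {}).items()
}
print(hf_fsdp_config)
# {'use_orig_params': True, 'auto_wrap_policy': 'TRANSFORMER_BASED_WRAP',
#  'transformer_layer_cls_to_wrap': 'LlamaDecoderLayer'}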
View File

@@ -20,6 +20,14 @@ from typing import Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
parallelize_module,
)
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -1233,6 +1241,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["fsdp_config"] = { training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items() k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
} }
# FIXME: hardcoded testing sizes
tp_size = int(os.environ.get("FSDP_TP_SIZE", 0))
if tp_size > 0:
world_size = int(os.environ.get("WORLD_SIZE", 1))
dp_size = world_size // tp_size
from torch.distributed.device_mesh import init_device_mesh
device_mesh = init_device_mesh(
"cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
dp_mesh = device_mesh["dp"]
tp_mesh = device_mesh["tp"]
training_arguments_kwargs["fsdp_config"]["device_mesh"] = dp_mesh
self.parallelize_model(tp_mesh)
if self.cfg.adapter == "qlora": if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True training_arguments_kwargs["qlora"] = True
@@ -1605,6 +1627,67 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        return trainer

    def parallelize_model(self, device_mesh, loss_parallel=False):
        # FIXME hardcoded for llama
        tp_mesh = device_mesh["tp"]
        # shard the output projection column-wise; replicate the logits unless loss_parallel is set
        parallelize_module(
            self.model,
            tp_mesh,
            {
                "lm_head": ColwiseParallel(
                    input_layouts=Shard(1),
                    output_layouts=Shard(-1) if loss_parallel else Replicate(),
                    use_local_output=not loss_parallel,
                ),
            },
        )
        # row-shard the token embedding and sequence-shard the final norm
        parallelize_module(
            self.model.model,
            tp_mesh,
            {
                "embed_tokens": RowwiseParallel(
                    input_layouts=Replicate(),
                    output_layouts=Shard(1),
                ),
                "norm": SequenceParallel(),
            },
        )
        for transformer_block in self.model.model.layers:
            layer_plan = {
                "input_layernorm": SequenceParallel(),
                "self_attn": PrepareModuleInput(
                    input_layouts=(Shard(1),),
                    desired_input_layouts=(Replicate(),),
                ),
                "self_attn.q_proj": ColwiseParallel(),
                "self_attn.k_proj": ColwiseParallel(),
                "self_attn.v_proj": ColwiseParallel(),
                "self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
                "post_attention_layernorm": SequenceParallel(),
                "mlp": PrepareModuleInput(
                    input_layouts=(Shard(1),),
                    desired_input_layouts=(Replicate(),),
                ),
                "mlp.gate_proj": ColwiseParallel(),
                "mlp.up_proj": ColwiseParallel(),
                "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
            }
            # each tp rank only sees its share of the attention heads
            self_attn = transformer_block.self_attn
            self_attn.num_heads = self_attn.num_heads // tp_mesh.size()
            self_attn.num_key_value_heads = (
                self_attn.num_key_value_heads // tp_mesh.size()
            )
            # TODO need to fix self_attn.rotary_emb
            parallelize_module(
                transformer_block,
                tp_mesh,
                layer_plan,
            )

    def build_collator(
        self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
    ):
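To make the layer plan above concrete, here is a toy, self-contained sketch, not taken from this PR, of the ColwiseParallel → RowwiseParallel pairing used for q/k/v_proj + o_proj and gate/up_proj + down_proj (assumes 2 GPUs under torchrun): the column-wise half shards output features, the row-wise half shards the matching input features, so the intermediate activation stays sharded per rank and only the second projection's output is reduced.

# Toy sketch, not from this PR: run with `torchrun --nproc_per_node=2 toy_tp.py`.
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    def __init__(self, dim=128, hidden=512):
        super().__init__()
        self.up_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))


torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # one GPU per rank under torchrun
torch.manual_seed(0)  # identical weights/inputs on every rank; TP assumes replicated inputs

tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
model = ToyMLP().cuda()
parallelize_module(
    model,
    tp_mesh,
    {
        # up_proj keeps half of its *output* features on each rank ...
        "up_proj": ColwiseParallel(),
        # ... and down_proj keeps the matching half of its *input* features,
        # so the sharded activation flows straight through without a gather.
        "down_proj": RowwiseParallel(),
    },
)
out = model(torch.randn(4, 128, device="cuda"))  # replicated output on every rank

In the Llama plan above, the extra Shard(1) output layouts and the SequenceParallel norms additionally keep activations sharded along the sequence dimension between the attention and MLP blocks.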