2d parallel llama fsdp

Wing Lian
2024-08-23 00:02:14 -04:00
parent fefa95e350
commit 198f7cd893


@@ -20,6 +20,14 @@ from typing import Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
parallelize_module,
)
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -1233,6 +1241,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
            # FIXME: hardcoded testing sizes
            tp_size = int(os.environ.get("FSDP_TP_SIZE", 0))
            if tp_size > 0:
                world_size = int(os.environ.get("WORLD_SIZE", 1))
                dp_size = world_size // tp_size
                from torch.distributed.device_mesh import init_device_mesh

                # 2-D mesh: FSDP shards across "dp", tensor parallel across "tp"
                device_mesh = init_device_mesh(
                    "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
                )
                # FSDP should only see the data-parallel sub-mesh
                dp_mesh = device_mesh["dp"]
                training_arguments_kwargs["fsdp_config"]["device_mesh"] = dp_mesh
                self.parallelize_model(device_mesh)

if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
@@ -1605,6 +1626,60 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        return trainer
    def parallelize_model(self, device_mesh, loss_parallel=True):
        # FIXME: hardcoded for llama
        tp_mesh = device_mesh["tp"]

        # shard lm_head column-wise over the vocab dim; with loss_parallel the
        # logits stay sharded as a DTensor so cross-entropy can run in parallel
        parallelize_module(
            self.model,
            tp_mesh,
            {
                "lm_head": ColwiseParallel(
                    input_layouts=Shard(1),
                    output_layouts=Shard(-1) if loss_parallel else Replicate(),
                    use_local_output=not loss_parallel,
                ),
            },
        )

        # embeddings shard the vocab dim row-wise and hand off sequence-sharded
        # activations; the final norm runs sequence-parallel
        parallelize_module(
            self.model.model,
            tp_mesh,
            {
                "embed_tokens": RowwiseParallel(
                    input_layouts=Replicate(),
                    output_layouts=Shard(1),
                ),
                "norm": SequenceParallel(),
            },
        )

        # model.layers is an nn.ModuleList, so iterate it directly
        for transformer_block in self.model.model.layers:
            # column-wise q/k/v and gate/up pair with row-wise o_proj and
            # down_proj, so each sublayer needs only a single all-reduce
            layer_plan = {
                "input_layernorm": SequenceParallel(),
                "self_attn": PrepareModuleInput(
                    input_layouts=(Shard(1), None),
                    desired_input_layouts=(Replicate(), None),
                ),
                "self_attn.q_proj": ColwiseParallel(),
                "self_attn.k_proj": ColwiseParallel(),
                "self_attn.v_proj": ColwiseParallel(),
                "self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
                "post_attention_layernorm": SequenceParallel(),
                "mlp": PrepareModuleInput(
                    input_layouts=(Shard(1),),
                    desired_input_layouts=(Replicate(),),
                ),
                "mlp.gate_proj": ColwiseParallel(),
                "mlp.up_proj": ColwiseParallel(),
                "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
            }
            parallelize_module(
                transformer_block,
                tp_mesh,
                layer_plan,
            )
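As a sanity check on the column-wise/row-wise pairing in the plan above, here is a toy sketch of the MLP case in isolation. ToyMLP and its sizes are illustrative, not from the commit, and tp_mesh is the "tp" sub-mesh built earlier: gate_proj and up_proj shard the hidden dimension column-wise, so their outputs multiply element-wise with no communication, and the row-wise down_proj performs the single all-reduce per MLP.

import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyMLP(nn.Module):
    # illustrative SwiGLU-style block using llama's gate/up/down naming
    def __init__(self, dim=1024, hidden=4096):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.up_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))

mlp = parallelize_module(
    ToyMLP().cuda(),
    tp_mesh,  # the "tp" sub-mesh from the 2-D device mesh
    {
        # column-wise shards keep the gate * up product local; the row-wise
        # down_proj then all-reduces the partial outputs
        "gate_proj": ColwiseParallel(),
        "up_proj": ColwiseParallel(),
        "down_proj": RowwiseParallel(),
    },
)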
    def build_collator(
        self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
    ):
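One caveat for callers of parallelize_model: with loss_parallel=True, lm_head returns vocab-sharded DTensor logits (use_local_output=not loss_parallel keeps them a DTensor), so the cross-entropy must run inside torch's loss_parallel() context, backward included. A hedged sketch, where logits and labels are assumed to come from the model forward and the batch:

import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

with loss_parallel():
    # cross_entropy over the vocab-sharded class dim, no all-gather of logits
    loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten(0, 1))
    # the backward through the sharded loss must also run inside the context
    loss.backward()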