From 198f7cd8934fa177916469537e513fec491ca570 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 23 Aug 2024 00:02:14 -0400
Subject: [PATCH] 2d parallel llama fsdp

---
 src/axolotl/core/trainer_builder.py | 99 +++++++++++++++++++++++++++++
 1 file changed, 99 insertions(+)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 1a073ca04..264ab1d74 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -20,6 +20,14 @@ from typing import Dict, List, Literal, Optional, Type, Union
 import torch
 import transformers
 from datasets import Dataset
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+    parallelize_module,
+)
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
@@ -1233,6 +1241,25 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 training_arguments_kwargs["fsdp_config"] = {
                     k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
                 }
+                # FIXME: hardcoded testing sizes
+                tp_size = int(os.environ.get("FSDP_TP_SIZE", 0))
+                if tp_size > 0:
+                    world_size = int(os.environ.get("WORLD_SIZE", 1))
+                    # dp_size is an integer division; reject sizes that do not divide
+                    if world_size % tp_size != 0:
+                        raise ValueError(
+                            f"WORLD_SIZE ({world_size}) must be divisible by "
+                            f"FSDP_TP_SIZE ({tp_size})"
+                        )
+                    dp_size = world_size // tp_size
+                    from torch.distributed.device_mesh import init_device_mesh
+
+                    device_mesh = init_device_mesh(
+                        "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
+                    )
+                    dp_mesh = device_mesh["dp"]
+                    training_arguments_kwargs["fsdp_config"]["device_mesh"] = dp_mesh
+                    self.parallelize_model(device_mesh)
 
         if self.cfg.adapter == "qlora":
             training_arguments_kwargs["qlora"] = True
@@ -1605,6 +1632,78 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         return trainer
 
+    def parallelize_model(self, device_mesh, loss_parallel=True):
+        """
+        Apply tensor/sequence parallel plans to a Llama model over the "tp" sub-mesh.
+
+        Plans are registered in place via parallelize_module for lm_head, for the
+        embedding and final norm, and for every decoder layer.
+
+        Args:
+            device_mesh: device mesh exposing a "tp" dimension.
+            loss_parallel: if True, lm_head logits stay sharded on the last dim and
+                are returned as a DTensor; otherwise they are replicated and returned
+                as a local tensor.
+        """
+        # FIXME hardcoded for llama
+        tp_mesh = device_mesh["tp"]
+
+        parallelize_module(
+            self.model,
+            tp_mesh,
+            {
+                "lm_head": ColwiseParallel(
+                    input_layouts=Shard(1),
+                    output_layouts=Shard(-1) if loss_parallel else Replicate(),
+                    use_local_output=not loss_parallel,
+                ),
+            },
+        )
+
+        parallelize_module(
+            self.model.model,
+            tp_mesh,
+            {
+                "embed_tokens": RowwiseParallel(
+                    input_layouts=Replicate(),
+                    output_layouts=Shard(1),
+                ),
+                "norm": SequenceParallel(),
+            },
+        )
+
+        # HF Llama stores decoder layers in an nn.ModuleList, which has no .items()
+        for transformer_block in self.model.model.layers:
+            layer_plan = {
+                "input_layernorm": SequenceParallel(),
+                # NOTE(review): these layouts only cover positional args; confirm how
+                # the decoder layer calls self_attn (kwargs would need
+                # input_kwarg_layouts / desired_input_kwarg_layouts instead).
+                "self_attn": PrepareModuleInput(
+                    input_layouts=(Shard(1), None),
+                    desired_input_layouts=(Replicate(), None),
+                ),
+                "self_attn.q_proj": ColwiseParallel(),
+                "self_attn.k_proj": ColwiseParallel(),
+                "self_attn.v_proj": ColwiseParallel(),
+                "self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
+                "post_attention_layernorm": SequenceParallel(),
+                "mlp": PrepareModuleInput(
+                    input_layouts=(Shard(1),),
+                    desired_input_layouts=(Replicate(),),
+                ),
+                # gate/up expand hidden -> intermediate (column-wise); down_proj
+                # consumes that sharded activation and reduces it (row-wise).
+                "mlp.gate_proj": ColwiseParallel(),
+                "mlp.up_proj": ColwiseParallel(),
+                "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
+            }
+            parallelize_module(
+                transformer_block,
+                tp_mesh,
+                layer_plan,
+            )
+
     def build_collator(
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):