From 756a34f0fe55d9b06ca782f3a69dbb269db0cd81 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 Aug 2024 10:57:57 -0400 Subject: [PATCH] wip for tp --- src/axolotl/core/trainer_builder.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 264ab1d74..1bc8bd92a 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1252,8 +1252,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp") ) dp_mesh = device_mesh["dp"] + tp_mesh = device_mesh["tp"] training_arguments_kwargs["fsdp_config"]["device_mesh"] = dp_mesh - self.parallelize_model(device_mesh) + self.parallelize_model(tp_mesh) if self.cfg.adapter == "qlora": training_arguments_kwargs["qlora"] = True @@ -1626,7 +1627,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return trainer - def parallelize_model(self, device_mesh, loss_parallel=True): + def parallelize_model(self, device_mesh, loss_parallel=False): # FIXME hardcoded for llama tp_mesh = device_mesh["tp"] @@ -1641,7 +1642,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ), }, ) - parallelize_module( self.model.model, tp_mesh, @@ -1654,12 +1654,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): }, ) - for _, transformer_block in self.model.model.layers.items(): + for _, transformer_block in enumerate(self.model.model.layers): layer_plan = { "input_layernorm": SequenceParallel(), "self_attn": PrepareModuleInput( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate()), ), "self_attn.q_proj": ColwiseParallel(), "self_attn.k_proj": ColwiseParallel(), @@ -1671,9 +1671,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): desired_input_layouts=(Replicate(),), ), "mlp.gate_proj": ColwiseParallel(), - "mlp.up_proj": RowwiseParallel(output_layouts=Shard(1)), - "mlp.down_proj": ColwiseParallel(), + "mlp.up_proj": ColwiseParallel(), + "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), } + self_attn = transformer_block.self_attn + self_attn.num_heads = self_attn.num_heads // tp_mesh.size() + self_attn.num_key_value_heads = ( + self_attn.num_key_value_heads // tp_mesh.size() + ) + + # TODO need to fix self_attn.rotary_emb + parallelize_module( transformer_block, tp_mesh,