diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py
index 8e0508ee0..82a46d9cf 100644
--- a/src/axolotl/integrations/liger/__init__.py
+++ b/src/axolotl/integrations/liger/__init__.py
@@ -42,6 +42,7 @@ class LigerPlugin(BasePlugin):
 
     def pre_model_load(self, cfg):
         if cfg.torch_compile:
+            # torch.compile will unnecessarily attempt to optimize the Triton kernel unless explicitly disabled
             import liger_kernel.ops.fused_linear_cross_entropy
 
             patch_with_compile_disable(
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index ee1869f7d..255422215 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -450,6 +450,76 @@ class TestMultiGPULlama:
             temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
         )
 
+    @pytest.mark.parametrize(
+        "fsdp_reshard_after_forward",
+        [True, False],
+    )
+    def test_fsdp2_packed(self, temp_dir, fsdp_reshard_after_forward):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sample_packing": True,
+                "pad_to_sequence_len": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.05,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 2,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 2,
+                "gradient_checkpointing": True,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "fsdp": [
+                    "auto_wrap",
+                ],
+                "fsdp_config": {
+                    "fsdp_version": 2,
+                    "fsdp_limit_all_gathers": True,
+                    "fsdp_offload_params": False,
+                    "fsdp_cpu_ram_efficient_loading": False,
+                    "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
+                    "fsdp_state_dict_type": "SHARDED_STATE_DICT",
+                    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                    "fsdp_reshard_after_forward": fsdp_reshard_after_forward,
+                },
+                "use_tensorboard": True,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
+                "--num-processes",
+                "2",
+                "--main-process-port",
+                f"{get_torch_dist_unique_port()}",
+            ]
+        )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     def test_fsdp_qlora_prequant_packed(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
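
For context on the first hunk: `patch_with_compile_disable` is used to keep `torch.compile` from tracing Liger's hand-written Triton kernels. The diff does not show the helper's body or full call signature, so the sketch below is only an assumption of how such a patcher could work — the `(module, func_name)` signature and the use of `torch.compiler.disable` are illustrative, not taken from the axolotl source:

```python
import torch


def patch_with_compile_disable(module, func_name):
    """Hypothetical sketch (signature assumed, not axolotl's actual code):
    rebind ``module.<func_name>`` behind ``torch.compiler.disable`` so
    TorchDynamo graph-breaks on it instead of trying to trace the kernel."""
    original = getattr(module, func_name)
    # torch.compiler.disable returns a wrapper that dynamo will skip entirely
    setattr(module, func_name, torch.compiler.disable(original))
```

Disabling compilation for just these functions lets the rest of the model still benefit from `torch.compile` while the Triton kernel runs as written.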