From 7a09f76644fd429b047549c494d5667ad3fd6098 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 11 Aug 2025 09:31:54 -0400 Subject: [PATCH] fix ray train and add fsdp2 smoke test for ray trainer (#3053) * add fsdp2 smoke test for ray trainer * fix raytrain with fsdp2 --- src/axolotl/cli/utils/train.py | 3 ++ tests/e2e/multigpu/test_ray.py | 74 +++++++++++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/src/axolotl/cli/utils/train.py b/src/axolotl/cli/utils/train.py index 3f9a6e4db..f1ac857b3 100644 --- a/src/axolotl/cli/utils/train.py +++ b/src/axolotl/cli/utils/train.py @@ -123,6 +123,9 @@ def launch_training( _launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec) elif launcher == "python": _launch_python_training(cfg_file, kwargs) + elif launcher is None: + # handle ray train launch + _launch_python_training(cfg_file, kwargs) def _launch_cloud_training( diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py index dd1422296..7f1278abf 100644 --- a/tests/e2e/multigpu/test_ray.py +++ b/tests/e2e/multigpu/test_ray.py @@ -10,7 +10,11 @@ from accelerate.test_utils import execute_subprocess_async from axolotl.utils.dict import DictDefault -from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0 +from tests.e2e.utils import ( + check_tensorboard, + require_torch_2_7_0, + require_torch_lt_2_6_0, +) AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent @@ -139,3 +143,71 @@ class TestMultiGPURay: check_tensorboard( temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) + + @require_torch_2_7_0 + @pytest.mark.parametrize( + "gradient_accumulation_steps", + [1, 2], + ) + def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "sample_packing": True, + "pad_to_sequence_len": True, + "sequence_len": 1024, + "val_set_size": 0.01, 
"special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 1, + "gradient_accumulation_steps": gradient_accumulation_steps, + "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "LlamaDecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "save_first_step": False, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--use-ray", + "--ray-num-workers", + "2", + ] + ) + + check_tensorboard( + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" + )