From 22680913f3ac4bf3410855210648f396cfd5c7d5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 27 Jul 2024 10:24:11 -0400 Subject: [PATCH] Bump deepspeed 20240727 (#1790) * pin deepspeed to 0.14.4 otherwise it doesn't play nice with trl * Add test to import to try to trigger import dependencies --- requirements.txt | 2 +- setup.py | 2 +- tests/e2e/test_imports.py | 20 ++++++++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 tests/e2e/test_imports.py diff --git a/requirements.txt b/requirements.txt index ec571570b..981a62558 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ transformers==4.43.1 tokenizers==0.19.1 bitsandbytes==0.43.1 accelerate==0.32.0 -deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b +deepspeed==0.14.4 pydantic==2.6.3 addict fire diff --git a/setup.py b/setup.py index ceba63669..1d164e0a1 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,7 @@ setup( "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib", ], "deepspeed": [ - "deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b", + "deepspeed==0.14.4", "deepspeed-kernels", ], "mamba-ssm": [ diff --git a/tests/e2e/test_imports.py b/tests/e2e/test_imports.py new file mode 100644 index 000000000..f186eaac4 --- /dev/null +++ b/tests/e2e/test_imports.py @@ -0,0 +1,20 @@ +""" +test module to import various submodules that have historically broken due to dependency issues +""" +import unittest + + +class TestImports(unittest.TestCase): + """ + Test class to import various submodules that have historically broken due to dependency issues + """ + + def test_import_causal_trainer(self): + from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401 + HFCausalTrainerBuilder, + ) + + def test_import_rl_trainer(self): + from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401 + HFRLTrainerBuilder, + )