From f9c7c3bb72a10a4fbedbe42359b65851dc76f66a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 26 Apr 2025 14:14:52 -0400
Subject: [PATCH] don't use is_main_process during config validation (#2569)

---
 .github/workflows/multi-gpu-e2e.yml |  1 +
 src/axolotl/utils/schemas/config.py | 25 +++++++++++--------------
 src/axolotl/utils/trainer.py        |  7 +++++++
 tests/e2e/patched/test_sp.py        |  5 -----
 4 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml
index aee4ddba6..2221bcfd4 100644
--- a/.github/workflows/multi-gpu-e2e.yml
+++ b/.github/workflows/multi-gpu-e2e.yml
@@ -8,6 +8,7 @@ on:
       - 'setup.py'
       - 'pyproject.toml'
       - '.github/workflows/multi-gpu-e2e.yml'
+      - 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
   workflow_dispatch:
   schedule:
     - cron: '0 0 * * 1,4'  # Runs at 00:00 UTC every monday & thursday
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index f68d160df..2e0a6027c 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -18,7 +18,6 @@ from pydantic import (
 )
 from transformers.utils.import_utils import is_torch_npu_available
 
-from axolotl.utils.distributed import is_main_process
 from axolotl.utils.schemas.datasets import (
     DatasetConfig,
     DPODataset,
@@ -719,10 +718,9 @@ class AxolotlInputConfig(
             and data.get("eval_sample_packing") is None
             and not data.get("eval_table_size")
         ):
-            if is_main_process():
-                LOG.info(
-                    "explicitly setting `eval_sample_packing` to match `sample_packing`"
-                )
+            LOG.info(
+                "explicitly setting `eval_sample_packing` to match `sample_packing`"
+            )
             data["eval_sample_packing"] = True
 
         if (
@@ -1179,15 +1177,14 @@ class AxolotlInputConfig(
         # TODO: monkeypatch / callback to average losses correctly across SP ranks
         # / fix gradient scaling across SP ranks. Losses, grads should be scaled
         # according to the proportion of non-padding tokens per rank.
-        if is_main_process():
-            LOG.warning(
-                "Sequence parallelism (SP) is enabled with "
-                f"sequence_parallel_degree={self.sequence_parallel_degree}. "
-                "Please note that logged losses may differ slightly to the non-SP "
-                "losses due to transformers Trainer implementation details. "
-                "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
-                "for more details."
-            )
+        LOG.warning(
+            "Sequence parallelism (SP) is enabled with "
+            f"sequence_parallel_degree={self.sequence_parallel_degree}. "
+            "Please note that logged losses may differ slightly to the non-SP "
+            "losses due to transformers Trainer implementation details. "
+            "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
+            "for more details."
+        )
 
         return self
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 3dc9ae3f6..69aaabfa6 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -528,6 +528,13 @@ def setup_torch_compile_env(cfg):
 def setup_deepspeed_env(cfg, stage=None):
     from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
 
+    from axolotl.utils.distributed import distributed_state
+
+    if distributed_state and distributed_state.initialized:
+        raise RuntimeError(
+            "Distributed State already initialized before Deepspeed setup"
+        )
+
     os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
     os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
     if stage:
diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py
index 6e1e2f2cb..046c482e3 100644
--- a/tests/e2e/patched/test_sp.py
+++ b/tests/e2e/patched/test_sp.py
@@ -131,11 +131,6 @@ class TestConfigValidation:
         # Mock the ring_flash_attn module
         monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
 
-        # Mock the is_main_process function to return True
-        monkeypatch.setattr(
-            "axolotl.utils.schemas.config.is_main_process", lambda: True
-        )
-
     @pytest.fixture
     def base_cfg(self):
         """Create a base configuration for testing."""
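Note on the rationale, read from the patch rather than stated in it: config validation runs before setup_deepspeed_env() exports the ACCELERATE_* environment variables, so anything that touches the shared distributed state during validation (as an is_main_process() call can) risks pinning accelerate's state before the DeepSpeed settings are in place; the new guard in setup_deepspeed_env() makes that ordering explicit. The sketch below restates that guard outside the patch as a minimal, hypothetical illustration: setup_deepspeed_env_sketch and its deepspeed_config argument are invented names, while distributed_state and its initialized flag come from the diff above.

# Hypothetical standalone sketch of the ordering guard added in the patch.
import os


def setup_deepspeed_env_sketch(deepspeed_config: str) -> None:
    # Imported lazily, mirroring the patch; `distributed_state.initialized`
    # is the flag the patch checks.
    from axolotl.utils.distributed import distributed_state

    # If config validation (or anything else) has already created the shared
    # distributed state, the env vars below would be exported too late for
    # the DeepSpeed configuration to be honored, so fail loudly instead.
    if distributed_state and distributed_state.initialized:
        raise RuntimeError(
            "Distributed State already initialized before Deepspeed setup"
        )

    os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
    os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = deepspeed_config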