don't use is_main_process during config validation (#2569)

This commit is contained in:
Wing Lian
2025-04-26 14:14:52 -04:00
committed by GitHub
parent caf5cb63ea
commit f9c7c3bb72
4 changed files with 19 additions and 19 deletions

View File

@@ -8,6 +8,7 @@ on:
- 'setup.py' - 'setup.py'
- 'pyproject.toml' - 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml' - '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
workflow_dispatch: workflow_dispatch:
schedule: schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

View File

@@ -18,7 +18,6 @@ from pydantic import (
) )
from transformers.utils.import_utils import is_torch_npu_available from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.distributed import is_main_process
from axolotl.utils.schemas.datasets import ( from axolotl.utils.schemas.datasets import (
DatasetConfig, DatasetConfig,
DPODataset, DPODataset,
@@ -719,10 +718,9 @@ class AxolotlInputConfig(
and data.get("eval_sample_packing") is None and data.get("eval_sample_packing") is None
and not data.get("eval_table_size") and not data.get("eval_table_size")
): ):
if is_main_process(): LOG.info(
LOG.info( "explicitly setting `eval_sample_packing` to match `sample_packing`"
"explicitly setting `eval_sample_packing` to match `sample_packing`" )
)
data["eval_sample_packing"] = True data["eval_sample_packing"] = True
if ( if (
@@ -1179,15 +1177,14 @@ class AxolotlInputConfig(
# TODO: monkeypatch / callback to average losses correctly across SP ranks # TODO: monkeypatch / callback to average losses correctly across SP ranks
# / fix gradient scaling across SP ranks. Losses, grads should be scaled # / fix gradient scaling across SP ranks. Losses, grads should be scaled
# according to the proportion of non-padding tokens per rank. # according to the proportion of non-padding tokens per rank.
if is_main_process(): LOG.warning(
LOG.warning( "Sequence parallelism (SP) is enabled with "
"Sequence parallelism (SP) is enabled with " f"sequence_parallel_degree={self.sequence_parallel_degree}. "
f"sequence_parallel_degree={self.sequence_parallel_degree}. " "Please note that logged losses may differ slightly to the non-SP "
"Please note that logged losses may differ slightly to the non-SP " "losses due to transformers Trainer implementation details. "
"losses due to transformers Trainer implementation details. " "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " "for more details."
"for more details." )
)
return self return self

View File

@@ -528,6 +528,13 @@ def setup_torch_compile_env(cfg):
def setup_deepspeed_env(cfg, stage=None): def setup_deepspeed_env(cfg, stage=None):
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
from axolotl.utils.distributed import distributed_state
if distributed_state and distributed_state.initialized:
raise RuntimeError(
"Distributed State already initialized before Deepspeed setup"
)
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage: if stage:

View File

@@ -131,11 +131,6 @@ class TestConfigValidation:
# Mock the ring_flash_attn module # Mock the ring_flash_attn module
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock()) monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
# Mock the is_main_process function to return True
monkeypatch.setattr(
"axolotl.utils.schemas.config.is_main_process", lambda: True
)
@pytest.fixture @pytest.fixture
def base_cfg(self): def base_cfg(self):
"""Create a base configuration for testing.""" """Create a base configuration for testing."""