don't use is_main_process during config validation (#2569)
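The config schema (axolotl.utils.schemas.config) previously gated two log messages behind is_main_process(). Calling that helper while the Pydantic config is being validated touches the distributed state before setup_deepspeed_env() has exported its Accelerate environment variables, so this commit drops the gate, removes the now-unused import, and adds a guard that raises if the distributed state is already initialized when DeepSpeed setup runs. Below is a minimal sketch of why the helper has that side effect, assuming is_main_process() is built on accelerate's PartialState; the actual implementation in axolotl.utils.distributed may differ.

    # Illustrative sketch only: assumes is_main_process() wraps accelerate's PartialState.
    from accelerate import PartialState


    def is_main_process() -> bool:
        # Constructing PartialState initializes Accelerate's shared process state
        # as a side effect. If this happens during config validation, that state
        # already exists by the time setup_deepspeed_env() tries to export
        # ACCELERATE_USE_DEEPSPEED and ACCELERATE_DEEPSPEED_CONFIG_FILE,
        # the situation the new RuntimeError in this commit guards against.
        return PartialState().is_main_process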
.github/workflows/multi-gpu-e2e.yml (vendored)
@@ -8,6 +8,7 @@ on:
       - 'setup.py'
       - 'pyproject.toml'
       - '.github/workflows/multi-gpu-e2e.yml'
+      - 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
   workflow_dispatch:
   schedule:
     - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -18,7 +18,6 @@ from pydantic import (
 )
 from transformers.utils.import_utils import is_torch_npu_available

-from axolotl.utils.distributed import is_main_process
 from axolotl.utils.schemas.datasets import (
     DatasetConfig,
     DPODataset,
@@ -719,10 +718,9 @@ class AxolotlInputConfig(
             and data.get("eval_sample_packing") is None
             and not data.get("eval_table_size")
         ):
-            if is_main_process():
-                LOG.info(
-                    "explicitly setting `eval_sample_packing` to match `sample_packing`"
-                )
+            LOG.info(
+                "explicitly setting `eval_sample_packing` to match `sample_packing`"
+            )
             data["eval_sample_packing"] = True

         if (
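With the is_main_process() gate removed, this info message may be emitted once per rank rather than only on rank zero, unless the logger itself filters by rank; the apparent trade-off is some duplicated log output in exchange for config validation that is free of distributed-state side effects.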
@@ -1179,15 +1177,14 @@ class AxolotlInputConfig(
             # TODO: monkeypatch / callback to average losses correctly across SP ranks
             # / fix gradient scaling across SP ranks. Losses, grads should be scaled
             # according to the proportion of non-padding tokens per rank.
-            if is_main_process():
-                LOG.warning(
-                    "Sequence parallelism (SP) is enabled with "
-                    f"sequence_parallel_degree={self.sequence_parallel_degree}. "
-                    "Please note that logged losses may differ slightly to the non-SP "
-                    "losses due to transformers Trainer implementation details. "
-                    "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
-                    "for more details."
-                )
+            LOG.warning(
+                "Sequence parallelism (SP) is enabled with "
+                f"sequence_parallel_degree={self.sequence_parallel_degree}. "
+                "Please note that logged losses may differ slightly to the non-SP "
+                "losses due to transformers Trainer implementation details. "
+                "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
+                "for more details."
+            )

         return self
@@ -528,6 +528,13 @@ def setup_torch_compile_env(cfg):
 def setup_deepspeed_env(cfg, stage=None):
     from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

+    from axolotl.utils.distributed import distributed_state
+
+    if distributed_state and distributed_state.initialized:
+        raise RuntimeError(
+            "Distributed State already initialized before Deepspeed setup"
+        )
+
     os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
     os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
     if stage:
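The new guard makes the ordering requirement explicit: ACCELERATE_USE_DEEPSPEED and ACCELERATE_DEEPSPEED_CONFIG_FILE must be exported before Accelerate's distributed state is created, which is why config validation, which runs earlier, can no longer call is_main_process().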
@@ -131,11 +131,6 @@ class TestConfigValidation:
         # Mock the ring_flash_attn module
         monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())

-        # Mock the is_main_process function to return True
-        monkeypatch.setattr(
-            "axolotl.utils.schemas.config.is_main_process", lambda: True
-        )
-
     @pytest.fixture
     def base_cfg(self):
         """Create a base configuration for testing."""