don't use is_main_process during config validation (#2569)
This commit is contained in:
1
.github/workflows/multi-gpu-e2e.yml
vendored
1
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -8,6 +8,7 @@ on:
|
|||||||
- 'setup.py'
|
- 'setup.py'
|
||||||
- 'pyproject.toml'
|
- 'pyproject.toml'
|
||||||
- '.github/workflows/multi-gpu-e2e.yml'
|
- '.github/workflows/multi-gpu-e2e.yml'
|
||||||
|
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from pydantic import (
|
|||||||
)
|
)
|
||||||
from transformers.utils.import_utils import is_torch_npu_available
|
from transformers.utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
|
||||||
from axolotl.utils.schemas.datasets import (
|
from axolotl.utils.schemas.datasets import (
|
||||||
DatasetConfig,
|
DatasetConfig,
|
||||||
DPODataset,
|
DPODataset,
|
||||||
@@ -719,10 +718,9 @@ class AxolotlInputConfig(
|
|||||||
and data.get("eval_sample_packing") is None
|
and data.get("eval_sample_packing") is None
|
||||||
and not data.get("eval_table_size")
|
and not data.get("eval_table_size")
|
||||||
):
|
):
|
||||||
if is_main_process():
|
LOG.info(
|
||||||
LOG.info(
|
"explicitly setting `eval_sample_packing` to match `sample_packing`"
|
||||||
"explicitly setting `eval_sample_packing` to match `sample_packing`"
|
)
|
||||||
)
|
|
||||||
data["eval_sample_packing"] = True
|
data["eval_sample_packing"] = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1179,15 +1177,14 @@ class AxolotlInputConfig(
|
|||||||
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
||||||
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
||||||
# according to the proportion of non-padding tokens per rank.
|
# according to the proportion of non-padding tokens per rank.
|
||||||
if is_main_process():
|
LOG.warning(
|
||||||
LOG.warning(
|
"Sequence parallelism (SP) is enabled with "
|
||||||
"Sequence parallelism (SP) is enabled with "
|
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
|
||||||
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
|
"Please note that logged losses may differ slightly to the non-SP "
|
||||||
"Please note that logged losses may differ slightly to the non-SP "
|
"losses due to transformers Trainer implementation details. "
|
||||||
"losses due to transformers Trainer implementation details. "
|
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||||
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
"for more details."
|
||||||
"for more details."
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
@@ -528,6 +528,13 @@ def setup_torch_compile_env(cfg):
|
|||||||
def setup_deepspeed_env(cfg, stage=None):
|
def setup_deepspeed_env(cfg, stage=None):
|
||||||
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
|
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import distributed_state
|
||||||
|
|
||||||
|
if distributed_state and distributed_state.initialized:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Distributed State already initialized before Deepspeed setup"
|
||||||
|
)
|
||||||
|
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
||||||
if stage:
|
if stage:
|
||||||
|
|||||||
@@ -131,11 +131,6 @@ class TestConfigValidation:
|
|||||||
# Mock the ring_flash_attn module
|
# Mock the ring_flash_attn module
|
||||||
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
|
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
|
||||||
|
|
||||||
# Mock the is_main_process function to return True
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"axolotl.utils.schemas.config.is_main_process", lambda: True
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def base_cfg(self):
|
def base_cfg(self):
|
||||||
"""Create a base configuration for testing."""
|
"""Create a base configuration for testing."""
|
||||||
|
|||||||
Reference in New Issue
Block a user