Example for Slurm and various fixes (#3038) [skip ci]
* slurm example and make preprocess play nicely * start slurm if it init file exists * remove incorrect comment * feat: add slurm docs --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Data handling specific to SFT."""
|
||||
|
||||
import functools
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Literal
|
||||
|
||||
@@ -104,6 +105,9 @@ def _prepare_standard_dataset(
|
||||
finally:
|
||||
loader.cleanup()
|
||||
|
||||
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
|
||||
return train_dataset, eval_dataset, -1, prompters
|
||||
|
||||
# Validate sample packing configuration for evaluation
|
||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||
|
||||
@@ -51,7 +51,10 @@ def init_distributed_state():
|
||||
global distributed_state # pylint: disable=global-statement
|
||||
if distributed_state is None:
|
||||
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
||||
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
||||
try:
|
||||
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def get_distributed_state() -> PartialState | None:
|
||||
|
||||
Reference in New Issue
Block a user