diff --git a/src/axolotl/train.py b/src/axolotl/train.py index b9b0e595d..5acfaf12c 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -18,6 +18,7 @@ from axolotl.common.cli import TrainerCliArgs from axolotl.logging_config import configure_logging from axolotl.monkeypatch import neft_embeddings from axolotl.utils.dict import DictDefault +from axolotl.utils.distributed import zero_only from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.trainer import setup_trainer @@ -44,7 +45,10 @@ def train( *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta ): # load the tokenizer first - LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") + with zero_only(): + LOG.debug( + f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}" + ) tokenizer = load_tokenizer(cfg) train_dataset = dataset_meta.train_dataset @@ -52,7 +56,10 @@ def train( total_num_steps = dataset_meta.total_num_steps # Load the model and tokenizer - LOG.info("loading model and (optionally) peft_config...") + msg = "loading model" + if cfg.adapter: + msg += " and peft_config..." + LOG.debug(msg) model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) safe_serialization = cfg.save_safetensors is True diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index 9a1c689fb..313dd24e8 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -50,6 +50,17 @@ def get_world_size(): return int(os.getenv("WORLD_SIZE", "1")) +@contextmanager +def zero_only(): + """ + Context manager that only runs the enclosed block on the main rank. + """ + if is_main_process(): + yield + else: + yield None + + @contextmanager def zero_first(is_main): """ diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index d04390293..9a6afcd17 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -21,6 +21,7 @@ from axolotl.utils.distributed import ( is_main_process, reduce_and_broadcast, zero_first, + zero_only, ) LOG = logging.getLogger("axolotl") @@ -153,14 +154,14 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): # we have to drop anything longer then sequence len otherwise # flash attention with position ids fails if not cfg.total_num_tokens: - LOG.info("calculating total_num_tokens") total_num_tokens = np.sum( train_dataset.data.column("input_ids") .to_pandas() .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda .values ) - LOG.info(f"total_num_tokens: {total_num_tokens}") + with zero_only(): + LOG.debug(f"total_num_tokens: {total_num_tokens}") cfg.total_num_tokens = total_num_tokens if not cfg.total_supervised_tokens: @@ -170,7 +171,8 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): .apply(lambda x: np.sum(np.array(x) != -100)) .sum() ) - LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`") + with zero_only(): + LOG.debug(f"`total_supervised_tokens: {total_supervised_tokens}`") cfg.total_supervised_tokens = total_supervised_tokens if cfg.sample_packing_eff_est: @@ -189,9 +191,10 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): ) * cfg.num_epochs ) - LOG.info( - f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}" - ) + with zero_only(): + LOG.debug( + f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}" + ) else: if cfg.world_size > 1 and is_distributed(): sampler = DistributedSampler( @@ -220,7 +223,8 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): ) data_loader_len = data_loader.len_w_stats() actual_eff = data_loader.efficiency() - LOG.info(f"data_loader_len: {data_loader_len}") + with zero_only(): + LOG.debug(f"data_loader_len: {data_loader_len}") # FIXME: is there a bug here somewhere? the total num steps depends # on the agreed on value for sample_packing_eff_est total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs)) @@ -237,12 +241,14 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0 ) cfg.sample_packing_eff_est = sample_packing_eff_est - LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}") + with zero_only(): + LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}") else: total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) - LOG.info(f"total_num_steps: {total_num_steps}") + with zero_only(): + LOG.debug(f"total_num_steps: {total_num_steps}") return total_num_steps