use accelerate logging for zero/main logging only

This commit is contained in:
Wing Lian
2023-11-06 07:27:42 -05:00
parent 4c834bf25d
commit b2430ce670
2 changed files with 23 additions and 24 deletions

View File

@@ -1,6 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
import os
import signal
import sys
@@ -10,6 +9,7 @@ from typing import Optional
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from transformers.deepspeed import is_deepspeed_zero3_enabled
@@ -18,7 +18,6 @@ from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -27,7 +26,7 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.train")
LOG = get_logger("axolotl.train")
@dataclass
@@ -45,10 +44,10 @@ def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
):
# load the tokenizer first
with zero_only():
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}"
)
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
train_dataset = dataset_meta.train_dataset

View File

@@ -1,5 +1,4 @@
"""Module containing the Trainer class and related functions"""
import logging
import math
import os
from contextlib import contextmanager
@@ -10,6 +9,7 @@ import numpy as np
import torch
import torch.cuda
import torch.distributed as dist
from accelerate.logging import get_logger
from datasets import set_caching_enabled
from torch.utils.data import DistributedSampler, RandomSampler
@@ -21,10 +21,9 @@ from axolotl.utils.distributed import (
is_main_process,
reduce_and_broadcast,
zero_first,
zero_only,
)
LOG = logging.getLogger("axolotl")
LOG = get_logger("axolotl")
@torch.jit.script
@@ -160,8 +159,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
with zero_only():
LOG.debug(f"total_num_tokens: {total_num_tokens}")
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens:
@@ -171,8 +169,10 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
with zero_only():
LOG.debug(f"`total_supervised_tokens: {total_supervised_tokens}`")
LOG.debug(
f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True,
)
cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing_eff_est:
@@ -191,10 +191,10 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
)
* cfg.num_epochs
)
with zero_only():
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
main_process_only=True,
)
else:
if cfg.world_size > 1 and is_distributed():
sampler = DistributedSampler(
@@ -223,8 +223,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
)
data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency()
with zero_only():
LOG.debug(f"data_loader_len: {data_loader_len}")
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
@@ -241,14 +240,15 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
)
cfg.sample_packing_eff_est = sample_packing_eff_est
with zero_only():
LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
LOG.debug(
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
main_process_only=True,
)
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
with zero_only():
LOG.debug(f"total_num_steps: {total_num_steps}")
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
return total_num_steps