use accelerate logging for zero/main logging only

This commit is contained in:
Wing Lian
2023-11-06 07:27:42 -05:00
parent 4c834bf25d
commit b2430ce670
2 changed files with 23 additions and 24 deletions

View File

@@ -1,6 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
import os import os
import signal import signal
import sys import sys
@@ -10,6 +9,7 @@ from typing import Optional
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate.logging import get_logger
from datasets import Dataset from datasets import Dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers.deepspeed import is_deepspeed_zero3_enabled from transformers.deepspeed import is_deepspeed_zero3_enabled
@@ -18,7 +18,6 @@ from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.monkeypatch import neft_embeddings from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
@@ -27,7 +26,7 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
configure_logging() configure_logging()
LOG = logging.getLogger("axolotl.train") LOG = get_logger("axolotl.train")
@dataclass @dataclass
@@ -45,10 +44,10 @@ def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
): ):
# load the tokenizer first # load the tokenizer first
with zero_only(): LOG.debug(
LOG.debug( f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}" main_process_only=True,
) )
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
train_dataset = dataset_meta.train_dataset train_dataset = dataset_meta.train_dataset

View File

@@ -1,5 +1,4 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import logging
import math import math
import os import os
from contextlib import contextmanager from contextlib import contextmanager
@@ -10,6 +9,7 @@ import numpy as np
import torch import torch
import torch.cuda import torch.cuda
import torch.distributed as dist import torch.distributed as dist
from accelerate.logging import get_logger
from datasets import set_caching_enabled from datasets import set_caching_enabled
from torch.utils.data import DistributedSampler, RandomSampler from torch.utils.data import DistributedSampler, RandomSampler
@@ -21,10 +21,9 @@ from axolotl.utils.distributed import (
is_main_process, is_main_process,
reduce_and_broadcast, reduce_and_broadcast,
zero_first, zero_first,
zero_only,
) )
LOG = logging.getLogger("axolotl") LOG = get_logger("axolotl")
@torch.jit.script @torch.jit.script
@@ -160,8 +159,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values .values
) )
with zero_only(): LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
LOG.debug(f"total_num_tokens: {total_num_tokens}")
cfg.total_num_tokens = total_num_tokens cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens: if not cfg.total_supervised_tokens:
@@ -171,8 +169,10 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
.apply(lambda x: np.sum(np.array(x) != -100)) .apply(lambda x: np.sum(np.array(x) != -100))
.sum() .sum()
) )
with zero_only(): LOG.debug(
LOG.debug(f"`total_supervised_tokens: {total_supervised_tokens}`") f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True,
)
cfg.total_supervised_tokens = total_supervised_tokens cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing_eff_est: if cfg.sample_packing_eff_est:
@@ -191,10 +191,10 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
) )
* cfg.num_epochs * cfg.num_epochs
) )
with zero_only(): LOG.debug(
LOG.debug( f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}" main_process_only=True,
) )
else: else:
if cfg.world_size > 1 and is_distributed(): if cfg.world_size > 1 and is_distributed():
sampler = DistributedSampler( sampler = DistributedSampler(
@@ -223,8 +223,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
) )
data_loader_len = data_loader.len_w_stats() data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency() actual_eff = data_loader.efficiency()
with zero_only(): LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
LOG.debug(f"data_loader_len: {data_loader_len}")
# FIXME: is there a bug here somewhere? the total num steps depends # FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est # on the agreed on value for sample_packing_eff_est
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs)) total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
@@ -241,14 +240,15 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0 math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
) )
cfg.sample_packing_eff_est = sample_packing_eff_est cfg.sample_packing_eff_est = sample_packing_eff_est
with zero_only(): LOG.debug(
LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}") f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
main_process_only=True,
)
else: else:
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
) )
with zero_only(): LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
LOG.debug(f"total_num_steps: {total_num_steps}")
return total_num_steps return total_num_steps