remove un-necessary zero-first guard as it's already only called in a parent fn (#1810) [skip ci]

This commit is contained in:
Wing Lian
2024-08-06 09:29:23 -04:00
committed by GitHub
parent ecdda006de
commit fbbeb4fee0

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger("axolotl") LOG = get_logger("axolotl")
@@ -183,88 +183,88 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
sequence_len=cfg.sequence_len, sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2, min_sequence_len=cfg.min_sample_len or 2,
) )
with zero_first(is_main_process()):
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if cfg.model_config_type == "mamba": if cfg.is_preprocess:
LOG.info("dropping attention_mask column") min_input_len = np.min(get_dataset_lengths(train_dataset))
train_dataset = train_dataset.remove_columns("attention_mask") LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
if eval_dataset: max_input_len = np.max(get_dataset_lengths(train_dataset))
eval_dataset = eval_dataset.remove_columns("attention_mask") LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if cfg.model_config_type == "falcon": if cfg.model_config_type == "mamba":
LOG.info("dropping token_type_ids column if it exists") LOG.info("dropping attention_mask column")
if "token_type_ids" in train_dataset.column_names: train_dataset = train_dataset.remove_columns("attention_mask")
train_dataset = train_dataset.remove_columns("token_type_ids") if eval_dataset:
if eval_dataset and "token_type_ids" in eval_dataset.column_names: eval_dataset = eval_dataset.remove_columns("attention_mask")
eval_dataset = eval_dataset.remove_columns("token_type_ids")
train_dataset = train_dataset.filter( if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long, drop_long,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences", desc="Dropping Long Sequences",
) )
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if cfg.group_by_length: if cfg.group_by_length:
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_length, add_length,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length", desc="Group By Length",
) )
if cfg.use_pose: if cfg.use_pose:
pose_kwargs = {} pose_kwargs = {}
if cfg.pose_num_chunks is not None: if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial( pose_fn = partial(
add_pose_position_ids, add_pose_position_ids,
max_context_len=cfg.pose_max_context_len, max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids, split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs, **pose_kwargs,
) )
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
pose_fn, pose_fn,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)", desc="Add position_id column (PoSE)",
) )
train_dataset = train_dataset.sort("sequence_len") train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False: if cfg.eval_sample_packing is not False:
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.map( eval_dataset = eval_dataset.map(
pose_fn, pose_fn,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)", desc="Add position_id column (PoSE)",
) )
elif cfg.sample_packing: elif cfg.sample_packing:
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_position_ids, add_position_ids,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)", desc="Add position_id column (Sample Packing)",
) )
if cfg.eval_sample_packing is not False: if cfg.eval_sample_packing is not False:
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.map( eval_dataset = eval_dataset.map(
add_position_ids, add_position_ids,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)", desc="Add position_id column (Sample Packing)",
) )
return train_dataset, eval_dataset return train_dataset, eval_dataset