remove un-necessary zero-first guard as it's already only called in a parent fn (#1810) [skip ci]
This commit is contained in:
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
LOG = get_logger("axolotl")
|
LOG = get_logger("axolotl")
|
||||||
@@ -183,88 +183,88 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
sequence_len=cfg.sequence_len,
|
sequence_len=cfg.sequence_len,
|
||||||
min_sequence_len=cfg.min_sample_len or 2,
|
min_sequence_len=cfg.min_sample_len or 2,
|
||||||
)
|
)
|
||||||
with zero_first(is_main_process()):
|
|
||||||
if cfg.is_preprocess:
|
|
||||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
|
||||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
|
||||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
|
||||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
|
||||||
|
|
||||||
if cfg.model_config_type == "mamba":
|
if cfg.is_preprocess:
|
||||||
LOG.info("dropping attention_mask column")
|
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
||||||
if eval_dataset:
|
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||||
|
|
||||||
if cfg.model_config_type == "falcon":
|
if cfg.model_config_type == "mamba":
|
||||||
LOG.info("dropping token_type_ids column if it exists")
|
LOG.info("dropping attention_mask column")
|
||||||
if "token_type_ids" in train_dataset.column_names:
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
train_dataset = train_dataset.remove_columns("token_type_ids")
|
if eval_dataset:
|
||||||
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
|
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||||
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
|
||||||
|
|
||||||
train_dataset = train_dataset.filter(
|
if cfg.model_config_type == "falcon":
|
||||||
|
LOG.info("dropping token_type_ids column if it exists")
|
||||||
|
if "token_type_ids" in train_dataset.column_names:
|
||||||
|
train_dataset = train_dataset.remove_columns("token_type_ids")
|
||||||
|
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
|
||||||
|
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
||||||
|
|
||||||
|
train_dataset = train_dataset.filter(
|
||||||
|
drop_long,
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Dropping Long Sequences",
|
||||||
|
)
|
||||||
|
if eval_dataset:
|
||||||
|
eval_dataset = eval_dataset.filter(
|
||||||
drop_long,
|
drop_long,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Dropping Long Sequences",
|
desc="Dropping Long Sequences",
|
||||||
)
|
)
|
||||||
if eval_dataset:
|
|
||||||
eval_dataset = eval_dataset.filter(
|
|
||||||
drop_long,
|
|
||||||
num_proc=cfg.dataset_processes,
|
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
|
||||||
desc="Dropping Long Sequences",
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_length,
|
add_length,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Group By Length",
|
desc="Group By Length",
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.use_pose:
|
if cfg.use_pose:
|
||||||
pose_kwargs = {}
|
pose_kwargs = {}
|
||||||
if cfg.pose_num_chunks is not None:
|
if cfg.pose_num_chunks is not None:
|
||||||
pose_kwargs["chunks"] = cfg.pose_num_chunks
|
pose_kwargs["chunks"] = cfg.pose_num_chunks
|
||||||
pose_fn = partial(
|
pose_fn = partial(
|
||||||
add_pose_position_ids,
|
add_pose_position_ids,
|
||||||
max_context_len=cfg.pose_max_context_len,
|
max_context_len=cfg.pose_max_context_len,
|
||||||
split_on_token_ids=cfg.pose_split_on_token_ids,
|
split_on_token_ids=cfg.pose_split_on_token_ids,
|
||||||
**pose_kwargs,
|
**pose_kwargs,
|
||||||
)
|
)
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
pose_fn,
|
pose_fn,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
train_dataset = train_dataset.sort("sequence_len")
|
train_dataset = train_dataset.sort("sequence_len")
|
||||||
if cfg.eval_sample_packing is not False:
|
if cfg.eval_sample_packing is not False:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
pose_fn,
|
pose_fn,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
elif cfg.sample_packing:
|
elif cfg.sample_packing:
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (Sample Packing)",
|
desc="Add position_id column (Sample Packing)",
|
||||||
)
|
)
|
||||||
if cfg.eval_sample_packing is not False:
|
if cfg.eval_sample_packing is not False:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (Sample Packing)",
|
desc="Add position_id column (Sample Packing)",
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user