diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 7a9cf2fbb..02234d8b7 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
-from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
+from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

 LOG = get_logger("axolotl")
@@ -183,88 +183,88 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
         sequence_len=cfg.sequence_len,
         min_sequence_len=cfg.min_sample_len or 2,
     )
-    with zero_first(is_main_process()):
-        if cfg.is_preprocess:
-            min_input_len = np.min(get_dataset_lengths(train_dataset))
-            LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-            max_input_len = np.max(get_dataset_lengths(train_dataset))
-            LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
-        if cfg.model_config_type == "mamba":
-            LOG.info("dropping attention_mask column")
-            train_dataset = train_dataset.remove_columns("attention_mask")
-            if eval_dataset:
-                eval_dataset = eval_dataset.remove_columns("attention_mask")
+    if cfg.is_preprocess:
+        min_input_len = np.min(get_dataset_lengths(train_dataset))
+        LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
+        max_input_len = np.max(get_dataset_lengths(train_dataset))
+        LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

-        if cfg.model_config_type == "falcon":
-            LOG.info("dropping token_type_ids column if it exists")
-            if "token_type_ids" in train_dataset.column_names:
-                train_dataset = train_dataset.remove_columns("token_type_ids")
-            if eval_dataset and "token_type_ids" in eval_dataset.column_names:
-                eval_dataset = eval_dataset.remove_columns("token_type_ids")
+    if cfg.model_config_type == "mamba":
+        LOG.info("dropping attention_mask column")
+        train_dataset = train_dataset.remove_columns("attention_mask")
+        if eval_dataset:
+            eval_dataset = eval_dataset.remove_columns("attention_mask")

-        train_dataset = train_dataset.filter(
+    if cfg.model_config_type == "falcon":
+        LOG.info("dropping token_type_ids column if it exists")
+        if "token_type_ids" in train_dataset.column_names:
+            train_dataset = train_dataset.remove_columns("token_type_ids")
+        if eval_dataset and "token_type_ids" in eval_dataset.column_names:
+            eval_dataset = eval_dataset.remove_columns("token_type_ids")
+
+    train_dataset = train_dataset.filter(
+        drop_long,
+        num_proc=cfg.dataset_processes,
+        load_from_cache_file=not cfg.is_preprocess,
+        desc="Dropping Long Sequences",
+    )
+    if eval_dataset:
+        eval_dataset = eval_dataset.filter(
             drop_long,
             num_proc=cfg.dataset_processes,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Dropping Long Sequences",
         )
-        if eval_dataset:
-            eval_dataset = eval_dataset.filter(
-                drop_long,
-                num_proc=cfg.dataset_processes,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Dropping Long Sequences",
-            )

-        if cfg.group_by_length:
-            train_dataset = train_dataset.map(
-                add_length,
-                num_proc=cfg.dataset_processes,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Group By Length",
-            )
+    if cfg.group_by_length:
+        train_dataset = train_dataset.map(
+            add_length,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+            desc="Group By Length",
+        )

-        if cfg.use_pose:
-            pose_kwargs = {}
-            if cfg.pose_num_chunks is not None:
-                pose_kwargs["chunks"] = cfg.pose_num_chunks
-            pose_fn = partial(
-                add_pose_position_ids,
-                max_context_len=cfg.pose_max_context_len,
-                split_on_token_ids=cfg.pose_split_on_token_ids,
-                **pose_kwargs,
-            )
-            train_dataset = train_dataset.map(
-                pose_fn,
-                num_proc=cfg.dataset_processes,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Add position_id column (PoSE)",
-            )
-            train_dataset = train_dataset.sort("sequence_len")
-            if cfg.eval_sample_packing is not False:
-                if eval_dataset:
-                    eval_dataset = eval_dataset.map(
-                        pose_fn,
-                        num_proc=cfg.dataset_processes,
-                        load_from_cache_file=not cfg.is_preprocess,
-                        desc="Add position_id column (PoSE)",
-                    )
-        elif cfg.sample_packing:
-            train_dataset = train_dataset.map(
-                add_position_ids,
-                num_proc=cfg.dataset_processes,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Add position_id column (Sample Packing)",
-            )
-            if cfg.eval_sample_packing is not False:
-                if eval_dataset:
-                    eval_dataset = eval_dataset.map(
-                        add_position_ids,
-                        num_proc=cfg.dataset_processes,
-                        load_from_cache_file=not cfg.is_preprocess,
-                        desc="Add position_id column (Sample Packing)",
-                    )
+    if cfg.use_pose:
+        pose_kwargs = {}
+        if cfg.pose_num_chunks is not None:
+            pose_kwargs["chunks"] = cfg.pose_num_chunks
+        pose_fn = partial(
+            add_pose_position_ids,
+            max_context_len=cfg.pose_max_context_len,
+            split_on_token_ids=cfg.pose_split_on_token_ids,
+            **pose_kwargs,
+        )
+        train_dataset = train_dataset.map(
+            pose_fn,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+            desc="Add position_id column (PoSE)",
+        )
+        train_dataset = train_dataset.sort("sequence_len")
+        if cfg.eval_sample_packing is not False:
+            if eval_dataset:
+                eval_dataset = eval_dataset.map(
+                    pose_fn,
+                    num_proc=cfg.dataset_processes,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Add position_id column (PoSE)",
+                )
+    elif cfg.sample_packing:
+        train_dataset = train_dataset.map(
+            add_position_ids,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+            desc="Add position_id column (Sample Packing)",
+        )
+        if cfg.eval_sample_packing is not False:
+            if eval_dataset:
+                eval_dataset = eval_dataset.map(
+                    add_position_ids,
+                    num_proc=cfg.dataset_processes,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Add position_id column (Sample Packing)",
+                )

     return train_dataset, eval_dataset