Compare commits

..

2 Commits

Author SHA1 Message Date
Dan Saunders
c4f4f81bed Merge branch 'main' into map-dataset-fetcher-fix 2025-06-26 11:20:05 -04:00
Dan Saunders
4ebd4aae3d handle possibly empty batch 2025-06-26 10:59:27 -04:00
17 changed files with 377 additions and 10005 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -75,17 +75,13 @@ def load_datasets(
num_examples = cli_args.debug_num_examples if cli_args else 1 num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False text_only = cli_args.debug_text_only if cli_args else False
try: train_samples = sample_dataset(train_dataset, num_examples)
train_samples = sample_dataset(train_dataset, num_examples) check_dataset_labels(
check_dataset_labels( train_samples,
train_samples, tokenizer,
tokenizer, num_examples=num_examples,
num_examples=num_examples, text_only=text_only,
text_only=text_only, )
)
except AttributeError:
# can't sample iterable datasets
pass
LOG.info("printing prompters...") LOG.info("printing prompters...")
for prompter in prompters: for prompter in prompters:

View File

@@ -413,8 +413,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
or self.cfg.micro_batch_size > 1 or self.cfg.micro_batch_size > 1
): ):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn): return None
return None
if self.cfg.model_config_type == "mamba": if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer) return MambaDataCollator(tokenizer=self.tokenizer)

View File

@@ -776,9 +776,6 @@ class ModelLoader:
dist_dtype: torch.dtype, dist_dtype: torch.dtype,
before_kbit_train_or_finetune: bool, before_kbit_train_or_finetune: bool,
): ):
dest = {"dtype": dist_dtype}
if self.cfg.lora_on_cpu:
dest["device"] = "cpu"
for name, module in self.model.named_modules(): for name, module in self.model.named_modules():
if "norm" in name: if "norm" in name:
module.to(dist_dtype) module.to(dist_dtype)
@@ -789,4 +786,4 @@ class ModelLoader:
# don't upcast lm_head for btlm # don't upcast lm_head for btlm
continue continue
if any(m in name for m in embedding_modules) and hasattr(module, "weight"): if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
module.to(**dest) module.to(dist_dtype)

View File

@@ -9,6 +9,9 @@ from torch.utils.data._utils.worker import _worker_loop
class _MapDatasetFetcher(_BaseDatasetFetcher): class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index): def fetch(self, possibly_batched_index):
if not possibly_batched_index:
return self.collate_fn([])
if isinstance(possibly_batched_index[0], list): if isinstance(possibly_batched_index[0], list):
data = [None for i in possibly_batched_index] data = [None for i in possibly_batched_index]
for i, possibly_batched_index_ in enumerate(possibly_batched_index): for i, possibly_batched_index_ in enumerate(possibly_batched_index):

View File

@@ -156,12 +156,8 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
model_cls_prefix = "".join( model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")] [part.capitalize() for part in model_type.split("_")]
) )
if model_type == "gemma3n": module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"]) attention_cls = getattr(module, f"{model_cls_prefix}Attention")
attention_cls = getattr(module, f"{model_cls_prefix}TextAttention")
else:
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
return attention_cls return attention_cls
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:

View File

@@ -42,10 +42,6 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
if has_remote_code: if has_remote_code:
patch_remote(model_name) patch_remote(model_name)
elif hasattr(transformers, "modeling_flash_attention_utils"): elif hasattr(transformers, "modeling_flash_attention_utils"):
# sanity check in case upstream api changes on this
assert hasattr(
transformers.modeling_flash_attention_utils, "_get_unpad_data"
), "transformers api changed for _get_unpad_data for flash attention"
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
) )

View File

@@ -103,7 +103,6 @@ class ChatTemplatePrompter(Prompter):
chat_template_kwargs = { chat_template_kwargs = {
"chat_template": self.chat_template, "chat_template": self.chat_template,
"add_generation_prompt": add_generation_prompt, "add_generation_prompt": add_generation_prompt,
**self.chat_template_kwargs,
} }
if tools: if tools:

View File

@@ -223,8 +223,6 @@ def execute_training(
) )
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
if cfg.bf16:
torch.set_default_dtype(torch.bfloat16)
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -224,10 +224,10 @@ def wrap_pretraining_dataset(
remove_columns = [] remove_columns = []
if dataset.features is None: if dataset.features is None:
for first_row in dataset: for first_row in dataset:
remove_columns = list(first_row.keys()) remove_columns = first_row.keys()
break break
else: else:
remove_columns = list(dataset.features.keys()) remove_columns = dataset.features.keys()
dataset = dataset.map( dataset = dataset.map(
encode, encode,
@@ -267,7 +267,6 @@ def encode_packed_pretraining(
batch_size=1, batch_size=1,
batch_max_len=batch_size * max_seq_length, batch_max_len=batch_size * max_seq_length,
drop_last=True, drop_last=True,
num_processes=1,
) )
chunked_data = defaultdict(list) chunked_data = defaultdict(list)

View File

@@ -334,10 +334,7 @@ def _load_raw_datasets(
dataset = merge_datasets(datasets, cfg) dataset = merge_datasets(datasets, cfg)
if not cfg.skip_prepare_dataset: if not cfg.skip_prepare_dataset:
if split == "test" and cfg.eval_sequence_len: dataset = drop_long_seq_in_dataset(dataset, cfg)
dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
else:
dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
if cfg.sample_packing: if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None) dataset, _ = process_datasets_for_packing(cfg, dataset, None)

View File

@@ -148,14 +148,11 @@ def deduplicate_and_log_datasets(
return dataset, other_dataset return dataset, other_dataset
def drop_long_seq_in_dataset( def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
dataset: Dataset, sequence_len: int, cfg: DictDefault
) -> Dataset:
"""Remove sequences longer than configured maximum from dataset. """Remove sequences longer than configured maximum from dataset.
Args: Args:
dataset: Dataset to filter. dataset: Dataset to filter.
sequence_len: Maximum length for sequences to keep
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
Returns: Returns:
@@ -170,7 +167,7 @@ def drop_long_seq_in_dataset(
drop_long = functools.partial( drop_long = functools.partial(
drop_long_seq, drop_long_seq,
sequence_len=sequence_len, sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len, min_sequence_len=cfg.min_sample_len,
) )
@@ -190,7 +187,7 @@ def drop_long_seq_in_dataset(
drop_long_kwargs = {} drop_long_kwargs = {}
if filter_map_kwargs: if filter_map_kwargs:
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})" drop_long_kwargs["desc"] = "Dropping Long Sequences"
dataset = dataset.filter( dataset = dataset.filter(
drop_long, drop_long,

View File

@@ -260,7 +260,7 @@ class MultipackBatchSampler(BatchSampler):
lengths: np.ndarray, # Sequence lengths lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = True, # Whether to drop final batches (might be incomplete) drop_last: bool = True, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 4, # Number of times to estimate batch count num_count_samples: int = 8, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin bin_size: int = 200, # The max number of samples that can be packed in a single bin
@@ -335,13 +335,12 @@ class MultipackBatchSampler(BatchSampler):
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins] bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
else: else:
# Use parallel packing # Use parallel packing
num_processes = self.num_processes or 1
all_bins = pack_parallel( all_bins = pack_parallel(
lengths, lengths,
bin_capacity=self.batch_max_len, bin_capacity=self.batch_max_len,
group_size=self.group_size, group_size=self.group_size,
bin_size=self.bin_size, bin_size=self.bin_size,
num_processes=min(4, num_processes) if num_processes else 4, num_processes=max(4, self.num_processes) if self.num_processes else 4,
safe_mode=self.safe_mode, safe_mode=self.safe_mode,
mp_start_method=self.mp_start_method, mp_start_method=self.mp_start_method,
) )

View File

@@ -366,12 +366,6 @@ class AxolotlInputConfig(
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048" "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
}, },
) )
eval_sequence_len: int | None = Field(
default=None,
json_schema_extra={
"description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
},
)
min_sample_len: int | None = None min_sample_len: int | None = None
max_prompt_len: int = Field( max_prompt_len: int = Field(
default=512, default=512,
@@ -784,12 +778,6 @@ class AxolotlInputConfig(
"description": "Custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null." "description": "Custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null."
}, },
) )
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "Additional kwargs to pass to the chat template. This is useful for customizing the chat template. For example, you can pass `thinking=False` to add a generation prompt to the chat template."
},
)
eot_tokens: list[str] | None = Field( eot_tokens: list[str] | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={

View File

@@ -462,20 +462,6 @@ class TrainingValidationMixin:
return data return data
@model_validator(mode="before")
@classmethod
def pretrain_with_tps(cls, data):
if data.get("pretraining_dataset") and data.get(
"include_tokens_per_second", False
):
# combining these would raise `TypeError: cannot pickle 'dict_keys' object`
# due to trying to count the number of tokens total in the dataset
raise ValueError(
"pretraining_dataset and include_tokens_per_second cannot be used together."
)
return data
class LoRAValidationMixin: class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration.""" """Validation methods related to LoRA/QLoRA configuration."""

View File

@@ -381,7 +381,6 @@ def process_pretraining_datasets_for_packing(
if not skip_position_ids: if not skip_position_ids:
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_position_ids, add_position_ids,
batched=True,
desc="Add position_id column (Pretraining Sample Packing)", desc="Add position_id column (Pretraining Sample Packing)",
) )
if drop_attention_mask: if drop_attention_mask:

View File

@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
) )
train_dataset = concatenate_datasets([dataset_wrapper]) train_dataset = concatenate_datasets([dataset_wrapper])
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg) train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
lengths = get_dataset_lengths(train_dataset) lengths = get_dataset_lengths(train_dataset)
batch_sampler = MultipackBatchSampler( batch_sampler = MultipackBatchSampler(