add SP doc, review comments

This commit is contained in:
Dan Saunders
2025-03-18 20:04:48 +00:00
parent 411df76a97
commit c1a58339e8
7 changed files with 115 additions and 17 deletions

View File

@@ -590,7 +590,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not (self.args.sample_packing and not self.args.pretraining):
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)

View File

@@ -166,10 +166,22 @@ def setup_signal_handler(
)
def train_context_manager(enable=False) -> contextlib.AbstractContextManager:
"""Configure CUDA SDP kernel settings if enabled."""
if enable:
def train_context_manager(
flash_optimum: bool = False,
) -> contextlib.AbstractContextManager:
"""
Instantiate CUDA SDP kernel context manager if `flash_optimum` is `True`.
Args:
flash_optimum: Whether to enable efficient backends for SDP attention.
Returns:
Context manager for temporarily enabling efficient backends for SDP attention
if `flash_optimum` is `True`, or `contextlib.nullcontext` otherwise.
"""
if flash_optimum:
return torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
@@ -190,7 +202,7 @@ def execute_training(
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
context_manager = train_context_manager(cfg.flash_optimum)
context_manager = train_context_manager(flash_optimum=cfg.flash_optimum)
with context_manager:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -17,19 +17,14 @@ logger = logging.getLogger(__name__)
def adjust_position_ids_for_slice(
position_ids: list | torch.Tensor, start_idx: int
position_ids: torch.Tensor, start_idx: int
) -> torch.Tensor:
"""
Adjust position IDs for a sliced sequence to maintain proper relative positions.
This handles the case where position IDs might not be contiguous due to sample packing.
This handles the case where position IDs might not be contiguous due to sample
packing.
"""
# Convert to tensor if not already
if not isinstance(position_ids, torch.Tensor):
position_ids = torch.tensor(
position_ids,
device=position_ids.device if hasattr(position_ids, "device") else None,
)
# Find the boundaries between samples (where position_ids reset)
adjusted_pos_ids = position_ids.clone()