add SP doc, review comments
This commit is contained in:
@@ -590,7 +590,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
if not (self.args.sample_packing and not self.args.pretraining):
|
||||
if not self.args.sample_packing or self.args.pretraining:
|
||||
train_dataset = self._remove_unused_columns(
|
||||
train_dataset, description="training"
|
||||
)
|
||||
|
||||
@@ -166,10 +166,22 @@ def setup_signal_handler(
|
||||
)
|
||||
|
||||
|
||||
def train_context_manager(enable=False) -> contextlib.AbstractContextManager:
|
||||
"""Configure CUDA SDP kernel settings if enabled."""
|
||||
if enable:
|
||||
def train_context_manager(
|
||||
flash_optimum: bool = False,
|
||||
) -> contextlib.AbstractContextManager:
|
||||
"""
|
||||
Instantiate CUDA SDP kernel context manager if `flash_optimum` is `True`.
|
||||
|
||||
Args:
|
||||
flash_optimum: Whether to enable efficient backends for SDP attention.
|
||||
|
||||
Returns:
|
||||
Context manager for temporarily enabling efficient backends for SDP attention
|
||||
if `flash_optimum` is `True`, or `contextlib.nullcontext` otherwise.
|
||||
"""
|
||||
if flash_optimum:
|
||||
return torch.backends.cuda.sdp_kernel(
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
enable_flash=True,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=True,
|
||||
@@ -190,7 +202,7 @@ def execute_training(
|
||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||
"""
|
||||
LOG.info("Starting trainer...")
|
||||
context_manager = train_context_manager(cfg.flash_optimum)
|
||||
context_manager = train_context_manager(flash_optimum=cfg.flash_optimum)
|
||||
with context_manager:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
|
||||
@@ -17,19 +17,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def adjust_position_ids_for_slice(
|
||||
position_ids: list | torch.Tensor, start_idx: int
|
||||
position_ids: torch.Tensor, start_idx: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
||||
This handles the case where position IDs might not be contiguous due to sample packing.
|
||||
This handles the case where position IDs might not be contiguous due to sample
|
||||
packing.
|
||||
"""
|
||||
# Convert to tensor if not already
|
||||
if not isinstance(position_ids, torch.Tensor):
|
||||
position_ids = torch.tensor(
|
||||
position_ids,
|
||||
device=position_ids.device if hasattr(position_ids, "device") else None,
|
||||
)
|
||||
|
||||
# Find the boundaries between samples (where position_ids reset)
|
||||
adjusted_pos_ids = position_ids.clone()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user