precommit fixes
This commit is contained in:
@@ -413,13 +413,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
if self.args.curriculum_sampling:
|
||||
sampler = SequentialSampler(self.train_dataset)
|
||||
else:
|
||||
generator = None
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(self.args.seed)
|
||||
sampler = RandomSampler(
|
||||
self.train_dataset, generator=generator
|
||||
)
|
||||
sampler = RandomSampler(self.train_dataset)
|
||||
|
||||
return MultipackBatchSampler(
|
||||
sampler,
|
||||
@@ -434,13 +428,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
if self.args.curriculum_sampling:
|
||||
return SequentialSampler(self.train_dataset)
|
||||
|
||||
sampler = super()._get_train_sampler()
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(self.args.seed)
|
||||
sampler.generator = generator
|
||||
|
||||
return sampler
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
@@ -498,28 +486,30 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
self.accelerator.even_batches = False
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
return DataLoader(train_dataset, **dataloader_params)
|
||||
else:
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
else:
|
||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
||||
|
||||
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_train_sampler()
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
return DataLoader(train_dataset, **dataloader_params)
|
||||
else:
|
||||
return self.accelerator.prepare(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||
train_dataset = self._remove_unused_columns(
|
||||
train_dataset, description="training"
|
||||
)
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(
|
||||
data_collator, description="training"
|
||||
)
|
||||
|
||||
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_train_sampler()
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
return DataLoader(train_dataset, **dataloader_params)
|
||||
|
||||
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||
@@ -842,7 +832,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
num_items_in_batch: int = None,
|
||||
num_items_in_batch: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs.
|
||||
|
||||
@@ -218,7 +218,9 @@ class DataCollatorForSeq2Seq:
|
||||
f"Non-padding tokens: {non_pad_tokens_total}"
|
||||
)
|
||||
logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
|
||||
logger.info(f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}")
|
||||
logger.info(
|
||||
f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}"
|
||||
)
|
||||
|
||||
batch[key] = batch[key][:, start_idx:end_idx]
|
||||
|
||||
|
||||
@@ -11,10 +11,7 @@ from accelerate.state import PartialState
|
||||
ring_flash_attn_mock = MagicMock()
|
||||
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
from axolotl.utils.collators.sequence_parallel import (
|
||||
adjust_position_ids_for_slice,
|
||||
check_for_boundary_splits,
|
||||
)
|
||||
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
||||
|
||||
|
||||
# Create a fixture for PartialState
|
||||
@@ -52,37 +49,6 @@ class TestSequenceParallelHelpers:
|
||||
assert torch.all(adjusted[0] == expected_first_seq)
|
||||
assert torch.all(adjusted[1] == expected_second_seq)
|
||||
|
||||
def test_check_for_boundary_splits(self):
|
||||
"""Test detection of boundaries near slice edges."""
|
||||
# Boundaries at positions 10, 25, 40
|
||||
boundaries = [10, 25, 40]
|
||||
|
||||
# Test case where two boundaries are near edges (one at start, one at end)
|
||||
problems = check_for_boundary_splits(boundaries, slice_start=8, slice_end=30)
|
||||
assert (
|
||||
len(problems) == 2
|
||||
) # Both boundary at 10 (near start) and 25 (near end) are problems
|
||||
|
||||
# Check first problem - boundary near start
|
||||
assert problems[0][0] == 10 # The boundary position
|
||||
assert problems[0][1] == "start" # Type of issue
|
||||
assert problems[0][2] == 2 # Distance from start
|
||||
|
||||
# Check second problem - boundary near end
|
||||
assert problems[1][0] == 25 # The boundary position
|
||||
assert problems[1][1] == "end" # Type of issue
|
||||
assert problems[1][2] == 5 # Distance from end
|
||||
|
||||
# Test case with only one problem at the end
|
||||
problems = check_for_boundary_splits(boundaries, slice_start=15, slice_end=27)
|
||||
assert len(problems) == 1 # Only boundary at 25 is near the end
|
||||
assert problems[0][0] == 25 # The boundary
|
||||
assert problems[0][1] == "end" # Type of issue
|
||||
|
||||
# Test case with no problems
|
||||
problems = check_for_boundary_splits(boundaries, slice_start=12, slice_end=20)
|
||||
assert len(problems) == 0
|
||||
|
||||
|
||||
class TestRingAttention:
|
||||
"""Tests for the ring attention functionality."""
|
||||
|
||||
Reference in New Issue
Block a user