Compare commits
6 Commits
diffusion-
...
sp-fix-mas
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
954b989e88 | ||
|
|
c64c881460 | ||
|
|
cefd57cecb | ||
|
|
2f3c52ea2f | ||
|
|
741015b3cf | ||
|
|
4188700b7b |
@@ -235,6 +235,9 @@ class AxolotlTrainer(
|
|||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
|
|
||||||
# Return unprepared dataloader if using sequence parallelism
|
# Return unprepared dataloader if using sequence parallelism
|
||||||
|
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||||
|
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||||
|
# slice each batch along the sequence dimension).
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
|
|||||||
@@ -1,34 +1,22 @@
|
|||||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from torch import nn
|
|
||||||
from torch.utils.data import DistributedSampler, Sampler
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
|
||||||
from ring_flash_attn import update_ring_flash_attn_params
|
|
||||||
except ImportError:
|
|
||||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
|
||||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SequenceParallelMixin:
|
class SequenceParallelMixin:
|
||||||
"""
|
"""
|
||||||
Mixin class for sequence parallelism support in trainers.
|
Mixin class for sequence parallelism support in trainers.
|
||||||
|
|
||||||
This mixin provides functionality for handling sequence parallelism,
|
This mixin provides functionality for handling sequence parallelism,
|
||||||
including creating appropriate samplers, managing data partitioning,
|
specifically for creating appropriate data samplers.
|
||||||
and updating ring flash attention parameters during training.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
@@ -99,84 +87,3 @@ class SequenceParallelMixin:
|
|||||||
return self._create_sequence_parallel_sampler(
|
return self._create_sequence_parallel_sampler(
|
||||||
eval_dataset, shuffle=False, is_eval=True
|
eval_dataset, shuffle=False, is_eval=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
|
|
||||||
"""
|
|
||||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
|
||||||
the substituted ring_flash_attn. This is accomplished by using the passed
|
|
||||||
`input_ids`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: Current batch of inputs.
|
|
||||||
"""
|
|
||||||
# At this point, inputs should already be partitioned by the sequence
|
|
||||||
# parallel data collator
|
|
||||||
batch_size = inputs["input_ids"].shape[0]
|
|
||||||
seq_len = inputs["input_ids"].shape[1]
|
|
||||||
packed_seq_lens = [seq_len] * batch_size
|
|
||||||
|
|
||||||
# Calculate the full sequence length across all GPUs in this SP group
|
|
||||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
|
||||||
|
|
||||||
cu_seqlens = torch.cumsum(
|
|
||||||
torch.tensor(
|
|
||||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
cu_seqlens = F.pad(
|
|
||||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
|
||||||
)
|
|
||||||
|
|
||||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
|
||||||
|
|
||||||
def training_step(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
inputs: dict[str, torch.Tensor | Any],
|
|
||||||
num_items_in_batch: int | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Perform a training step on a batch of inputs. Overrides the
|
|
||||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
|
||||||
enabled.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model to perform training step for.
|
|
||||||
inputs: Dictionary mapping.
|
|
||||||
"""
|
|
||||||
# Set up sequence parallelism for this step if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self._update_ring_flash_attn_params(inputs)
|
|
||||||
|
|
||||||
# Proceed with normal training step
|
|
||||||
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
|
||||||
|
|
||||||
def prediction_step(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
inputs: dict[str, torch.Tensor | Any],
|
|
||||||
prediction_loss_only: bool,
|
|
||||||
ignore_keys: list[str] | None = None,
|
|
||||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
|
||||||
"""
|
|
||||||
Perform a prediction step on a batch of inputs. Overrides the
|
|
||||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
|
||||||
enabled.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model to perform prediction step for.
|
|
||||||
inputs: Dictionary mapping of inputs.
|
|
||||||
prediction_loss_only: Whether to return only the loss.
|
|
||||||
ignore_keys: Keys to ignore in the inputs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (loss, logits, labels).
|
|
||||||
"""
|
|
||||||
# Set up sequence parallelism for this prediction step if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self._update_ring_flash_attn_params(inputs)
|
|
||||||
|
|
||||||
# Proceed with normal prediction step
|
|
||||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
|
||||||
|
|||||||
@@ -6,10 +6,12 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
|||||||
their sequence parallel version of Flash Attention 2.
|
their sequence parallel version of Flash Attention 2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
@@ -98,3 +100,27 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
|||||||
substitute_hf_flash_attn(
|
substitute_hf_flash_attn(
|
||||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def update_ring_attn_params(batch: dict[str, torch.Tensor]):
|
||||||
|
"""
|
||||||
|
Calculate the cumulative sequence lengths for the current forward pass and pass the
|
||||||
|
value to the substituted `ring_flash_attn`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: A dictionary with a batch of data. May or may not contain `position_ids`
|
||||||
|
data; if not, we compute it.
|
||||||
|
"""
|
||||||
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
|
||||||
|
input_ids = batch["input_ids"]
|
||||||
|
position_ids = batch.get("position_ids")
|
||||||
|
if position_ids is None:
|
||||||
|
seq_len = input_ids.shape[1]
|
||||||
|
position_ids = torch.arange(
|
||||||
|
0, seq_len, dtype=torch.long, device=input_ids.device
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
||||||
|
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||||
|
|||||||
@@ -96,7 +96,9 @@ def get_cu_seqlens(attn_mask):
|
|||||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
def get_cu_seqlens_from_pos_ids(
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||||
if len(position_ids.shape) == 1:
|
if len(position_ids.shape) == 1:
|
||||||
position_ids = position_ids.unsqueeze(0)
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ Data collators for axolotl to pad labels and position_ids for packed sequences.
|
|||||||
includes logic for handling sequence parallelism collation.
|
includes logic for handling sequence parallelism collation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
@@ -13,46 +12,7 @@ import torch.distributed as dist
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
|
||||||
|
|
||||||
|
|
||||||
def adjust_position_ids_for_slice(
|
|
||||||
position_ids: torch.Tensor, start_idx: int
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
|
||||||
This handles the case where position IDs might not be contiguous due to sample
|
|
||||||
packing.
|
|
||||||
"""
|
|
||||||
# Convert to tensor if not already
|
|
||||||
# Find the boundaries between samples (where position_ids reset)
|
|
||||||
adjusted_pos_ids = position_ids.clone()
|
|
||||||
|
|
||||||
# Process each sequence in the batch
|
|
||||||
for i in range(position_ids.shape[0]):
|
|
||||||
seq = position_ids[i]
|
|
||||||
|
|
||||||
# Find sample boundaries
|
|
||||||
boundaries = []
|
|
||||||
for j in range(1, len(seq)):
|
|
||||||
if seq[j] < seq[j - 1]:
|
|
||||||
boundaries.append(j)
|
|
||||||
|
|
||||||
# No need to adjust if there are no boundaries or this is a single sample
|
|
||||||
if not boundaries:
|
|
||||||
adjusted_pos_ids[i] = seq - start_idx
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Adjust each segment separately
|
|
||||||
prev_boundary = 0
|
|
||||||
for boundary in boundaries:
|
|
||||||
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
|
||||||
prev_boundary = boundary
|
|
||||||
|
|
||||||
# Last segment
|
|
||||||
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
|
||||||
|
|
||||||
return adjusted_pos_ids
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -196,23 +156,20 @@ class DataCollatorForSeq2Seq:
|
|||||||
Returns:
|
Returns:
|
||||||
Sliced batch dictionary.
|
Sliced batch dictionary.
|
||||||
"""
|
"""
|
||||||
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
# Get local (start, end) for sequence parallelism slicing
|
||||||
|
total_seq_len = batch["input_ids"].shape[1]
|
||||||
|
slice_size = total_seq_len // self.local_world_size
|
||||||
|
start = self.local_rank * slice_size
|
||||||
|
end = start + slice_size
|
||||||
|
|
||||||
|
# Update params for ring attention calculation
|
||||||
|
update_ring_attn_params(batch=batch)
|
||||||
|
|
||||||
|
# Slice batch for sequence parallel processing
|
||||||
|
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||||
for key in keys_to_slice:
|
for key in keys_to_slice:
|
||||||
if key in batch:
|
if key in batch:
|
||||||
seq_len = batch[key].shape[1]
|
batch[key] = batch[key][:, start:end]
|
||||||
slice_size = seq_len // self.local_world_size
|
|
||||||
start_idx = self.local_rank * slice_size
|
|
||||||
end_idx = (
|
|
||||||
start_idx + slice_size
|
|
||||||
if self.local_rank < self.local_world_size - 1
|
|
||||||
else seq_len
|
|
||||||
)
|
|
||||||
batch[key] = batch[key][:, start_idx:end_idx]
|
|
||||||
|
|
||||||
# Special handling for position_ids
|
|
||||||
if key == "position_ids" and self.local_rank > 0:
|
|
||||||
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
|
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|||||||
@@ -1156,6 +1156,12 @@ class AxolotlInputConfig(
|
|||||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not info.data["micro_batch_size"] == 1:
|
||||||
|
raise ValueError(
|
||||||
|
"micro_batch_size must be set to 1 "
|
||||||
|
"due to a `ring-flash-attn` requirement"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||||
except ImportError as exception:
|
except ImportError as exception:
|
||||||
@@ -1165,6 +1171,18 @@ class AxolotlInputConfig(
|
|||||||
"or `pip install ring-flash-attn>=0.1.4`."
|
"or `pip install ring-flash-attn>=0.1.4`."
|
||||||
) from exception
|
) from exception
|
||||||
|
|
||||||
|
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
||||||
|
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
||||||
|
# according to the proportion of non-padding tokens per rank.
|
||||||
|
LOG.warning(
|
||||||
|
"Sequence parallelism (SP) is enabled with "
|
||||||
|
f"sequence_parallel_degree={value}. Please note that logged losses may "
|
||||||
|
"differ slightly to the non-SP losses due to transformers Trainer "
|
||||||
|
"implementation details. Please see "
|
||||||
|
"https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||||
|
"for more details."
|
||||||
|
)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
87
tests/e2e/multigpu/test_sp.py
Normal file
87
tests/e2e/multigpu/test_sp.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""E2E tests for sequence parallelism"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from ..utils import check_tensorboard
|
||||||
|
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceParallelism:
|
||||||
|
"""Test case for training with sequence parallelism enabled"""
|
||||||
|
|
||||||
|
def test_sequence_parallel_training(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"load_in_8bit": False,
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"strict": False,
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"adapter": "qlora",
|
||||||
|
"sample_packing": True,
|
||||||
|
"eval_sample_packing": True,
|
||||||
|
"pad_to_sequence_len": True,
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||||
|
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 8,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"loss_watchdog_threshold": 5.0,
|
||||||
|
"loss_watchdog_patience": 3,
|
||||||
|
"bf16": "auto",
|
||||||
|
"warmup_steps": 1,
|
||||||
|
"saves_per_epoch": 1,
|
||||||
|
"logging_steps": 1,
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"sequence_parallel_degree": 2,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"accelerate",
|
||||||
|
"launch",
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
"-m",
|
||||||
|
"axolotl.cli.train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
|
||||||
|
)
|
||||||
@@ -12,7 +12,6 @@ from axolotl.monkeypatch.attention.ring_attn import (
|
|||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
set_ring_attn_group,
|
set_ring_attn_group,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
@@ -48,33 +47,6 @@ def fixture_cfg():
|
|||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
class TestSequenceParallelHelpers:
|
|
||||||
"""Test helper functions used in sequence parallelism."""
|
|
||||||
|
|
||||||
def test_adjust_position_ids_for_slice(self, partial_state):
|
|
||||||
"""Test position_ids adjustment for sequence slices."""
|
|
||||||
# Create sample position_ids with multiple sequences
|
|
||||||
position_ids = torch.tensor(
|
|
||||||
[
|
|
||||||
# First sequence with 2 samples
|
|
||||||
[0, 1, 2, 3, 4, 0, 1, 2, 3],
|
|
||||||
# Second sequence with 3 samples
|
|
||||||
[0, 1, 2, 0, 1, 2, 3, 0, 1],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Adjust as if this was the second slice (start_idx = 4)
|
|
||||||
adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
|
|
||||||
|
|
||||||
# For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
|
|
||||||
# For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
|
|
||||||
expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
|
|
||||||
expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
|
|
||||||
|
|
||||||
assert torch.all(adjusted[0] == expected_first_seq)
|
|
||||||
assert torch.all(adjusted[1] == expected_second_seq)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRingAttention:
|
class TestRingAttention:
|
||||||
"""Tests for the ring attention functionality."""
|
"""Tests for the ring attention functionality."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user