Compare commits
5 Commits
seq-parall
...
822a8a6931
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
822a8a6931 | ||
|
|
1a51180637 | ||
|
|
7562aadf89 | ||
|
|
479f5e18dd | ||
|
|
945dcc5020 |
4
.github/workflows/multi-gpu-e2e.yml
vendored
4
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -4,10 +4,6 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- 'tests/e2e/multigpu/*.py'
|
- 'tests/e2e/multigpu/*.py'
|
||||||
- 'requirements.txt'
|
|
||||||
- 'setup.py'
|
|
||||||
- 'pyproject.toml'
|
|
||||||
- '.github/workflows/multi-gpu-e2e.yml'
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||||
|
|||||||
@@ -37,11 +37,15 @@ temp_dir = tempfile.mkdtemp()
|
|||||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||||
f.write(dockerfile_contents)
|
f.write(dockerfile_contents)
|
||||||
|
|
||||||
cicd_image = Image.from_dockerfile(
|
cicd_image = (
|
||||||
|
Image.from_dockerfile(
|
||||||
pathlib.Path(temp_dir) / "Dockerfile",
|
pathlib.Path(temp_dir) / "Dockerfile",
|
||||||
force_build=True,
|
force_build=True,
|
||||||
gpu="A10G",
|
gpu="A10G",
|
||||||
).env(df_args)
|
)
|
||||||
|
.env(df_args)
|
||||||
|
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||||
|
)
|
||||||
|
|
||||||
app = App("Axolotl CI/CD", secrets=[])
|
app = App("Axolotl CI/CD", secrets=[])
|
||||||
|
|
||||||
|
|||||||
@@ -407,10 +407,7 @@ save_total_limit: # Checkpoints saved at a time
|
|||||||
max_steps:
|
max_steps:
|
||||||
|
|
||||||
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
|
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
|
||||||
include_tokens_per_second: # Optional[bool]
|
include_tokens_per_second:
|
||||||
|
|
||||||
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
|
|
||||||
auto_find_batch_size: # Optional[bool]
|
|
||||||
|
|
||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ to leverage operator fusion and tensor re-use in order to improve speed and redu
|
|||||||
memory usage during the forward and backward passes of these calculations.
|
memory usage during the forward and backward passes of these calculations.
|
||||||
|
|
||||||
We currently support several common model architectures, including (but not limited to):
|
We currently support several common model architectures, including (but not limited to):
|
||||||
|
|
||||||
- `llama`
|
- `llama`
|
||||||
- `mistral`
|
- `mistral`
|
||||||
- `qwen2`
|
- `qwen2`
|
||||||
|
|||||||
@@ -13,12 +13,12 @@ liger-kernel==0.5.2
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.49.0
|
transformers==4.48.3
|
||||||
tokenizers>=0.21.0
|
tokenizers>=0.21.0
|
||||||
accelerate==1.3.0
|
accelerate==1.3.0
|
||||||
datasets==3.2.0
|
datasets==3.2.0
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.1
|
||||||
trl==0.15.1
|
trl==0.15.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
|
|||||||
@@ -123,6 +123,8 @@ class ModalCloud(Cloud):
|
|||||||
if env := self.get_env():
|
if env := self.get_env():
|
||||||
image = image.env(env)
|
image = image.env(env)
|
||||||
|
|
||||||
|
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def get_secrets(self):
|
def get_secrets(self):
|
||||||
|
|||||||
@@ -59,7 +59,6 @@ from axolotl.core.training_args import (
|
|||||||
AxolotlTrainingArguments,
|
AxolotlTrainingArguments,
|
||||||
)
|
)
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType, get_extract_fn
|
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
@@ -747,11 +746,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||||
if self.cfg.sp_ulysses_degree:
|
|
||||||
data_collator_kwargs["sp_extract_fn"] = get_extract_fn(
|
|
||||||
USPRingAttnType.ZIGZAG,
|
|
||||||
sp_ulysses_degree=self.cfg.sp_ulysses_degree
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
if is_peft_model(unwrapped_model):
|
if is_peft_model(unwrapped_model):
|
||||||
unwrapped_model.merge_adapter()
|
unwrapped_model.merge_adapter()
|
||||||
state_dict = unwrapped_model.state_dict()
|
state_dict = unwrapped_model.state_dict()
|
||||||
|
unwrapped_model.unmerge_adapter()
|
||||||
# Remove base_model and base_layer prefixes
|
# Remove base_model and base_layer prefixes
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k.removeprefix("base_model.model.")
|
k.removeprefix("base_model.model.")
|
||||||
@@ -104,5 +105,3 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||||
)
|
)
|
||||||
llm_model.load_weights(state_dict.items())
|
llm_model.load_weights(state_dict.items())
|
||||||
if is_peft_model(unwrapped_model):
|
|
||||||
unwrapped_model.unmerge_adapter()
|
|
||||||
|
|||||||
@@ -206,16 +206,6 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
sp_ulysses_degree: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Ulysses parallelism for hybrid sequence parallel long context attn"},
|
|
||||||
)
|
|
||||||
|
|
||||||
sp_ring_degree: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Ring attention parallelism for sequence parallel long context attn"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
|
|||||||
@@ -1,45 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
|
||||||
from yunchang import set_seq_parallel_pg, EXTRACT_FUNC_DICT
|
|
||||||
|
|
||||||
from axolotl.utils.distributed import get_world_size, get_rank
|
|
||||||
|
|
||||||
|
|
||||||
class USPRingAttnType(Enum):
|
|
||||||
BASIC = "basic"
|
|
||||||
ZIGZAG = "zigzag"
|
|
||||||
STRIPE = "stripe"
|
|
||||||
|
|
||||||
def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
|
|
||||||
from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward
|
|
||||||
|
|
||||||
fa_forward = build_usp_fa_forward(ring_impl_type)
|
|
||||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward
|
|
||||||
|
|
||||||
def get_extract_fn(ring_impl_type: USPRingAttnType, sp_ulysses_degree: int):
|
|
||||||
fn = EXTRACT_FUNC_DICT["basic"]
|
|
||||||
if ring_impl_type.value in EXTRACT_FUNC_DICT.keys():
|
|
||||||
fn = EXTRACT_FUNC_DICT[ring_impl_type.value]
|
|
||||||
|
|
||||||
# map bad key upstream
|
|
||||||
elif ring_impl_type == USPRingAttnType.STRIPE:
|
|
||||||
fn = EXTRACT_FUNC_DICT["strip"]
|
|
||||||
|
|
||||||
world_size = get_world_size()
|
|
||||||
rd = world_size // sp_ulysses_degree
|
|
||||||
|
|
||||||
return partial(fn, rank=get_rank(), world_size=world_size, rd=rd, ud=sp_ulysses_degree)
|
|
||||||
|
|
||||||
def set_usp_parallel_group(sp_ulysses_degree):
|
|
||||||
"""
|
|
||||||
setup distributed parallel group for USP attention
|
|
||||||
make sure this gets called before building any USP attention modules
|
|
||||||
:param sp_ulysses_degree:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
world_size = get_world_size()
|
|
||||||
rank = get_rank()
|
|
||||||
sp_ring_degree = world_size // sp_ulysses_degree
|
|
||||||
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
from typing import Optional, Tuple, Callable
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from yunchang import LongContextAttention
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
|
|
||||||
|
|
||||||
|
|
||||||
def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
|
|
||||||
usp_attn = LongContextAttention(ring_impl_type.value)
|
|
||||||
|
|
||||||
def flash_attention_forward(
|
|
||||||
module: torch.nn.Module, # pylint: disable=unused-argument
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor], # pylint: disable=unused-argument
|
|
||||||
dropout: float = 0.0,
|
|
||||||
scaling: Optional[float] = None,
|
|
||||||
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
|
|
||||||
softcap: Optional[float] = None,
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
) -> Tuple[torch.Tensor, None]:
|
|
||||||
attn_output = usp_attn(
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
dropout_p=dropout,
|
|
||||||
softmax_scale=scaling,
|
|
||||||
causal=True,
|
|
||||||
softcap=softcap,
|
|
||||||
)
|
|
||||||
return attn_output, None
|
|
||||||
|
|
||||||
return flash_attention_forward
|
|
||||||
@@ -4,13 +4,12 @@ import importlib
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import types
|
import types
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from peft import PeftModelForCausalLM
|
from peft import PeftModelForCausalLM
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoConfig
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
from axolotl.kernels.lora import (
|
from axolotl.kernels.lora import (
|
||||||
apply_lora_mlp_geglu,
|
apply_lora_mlp_geglu,
|
||||||
@@ -96,82 +95,61 @@ def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tens
|
|||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|
||||||
"""
|
|
||||||
Get the appropriate attention class by inspecting the model config.
|
|
||||||
Uses dynamic import to support any model architecture that follows
|
|
||||||
the standard transformers naming convention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The appropriate attention class for the model.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If `base_model` not specified or attention class cannot be imported
|
|
||||||
ImportError: If the model module or attention class doesn't exist
|
|
||||||
"""
|
|
||||||
if "base_model" not in cfg:
|
|
||||||
raise ValueError("base_model must be specified in config")
|
|
||||||
|
|
||||||
# Get model config without loading the model
|
|
||||||
model_config = AutoConfig.from_pretrained(cfg["base_model"])
|
|
||||||
model_type = model_config.model_type
|
|
||||||
|
|
||||||
# Special case for model_type = "qwen2"
|
|
||||||
if model_type == "qwen2":
|
|
||||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
|
|
||||||
|
|
||||||
return Qwen2Attention
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Dynamically import the module and attention class
|
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
|
||||||
module = __import__(
|
|
||||||
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
|
|
||||||
)
|
|
||||||
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
|
|
||||||
|
|
||||||
return attention_cls
|
|
||||||
except (ImportError, AttributeError) as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Could not import attention class for model_type: {model_type}. "
|
|
||||||
f"Error: {str(e)}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
def patch_self_attn_lora(cfg: DictDefault):
|
def patch_self_attn_lora(model: PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
Given an `axolotl` config, this method patches the inferred attention class forward
|
Patches the attention classes in a transformer model with optimized LoRA implementations.
|
||||||
pass with optimized LoRA implementations.
|
|
||||||
|
|
||||||
It modifies the attention class to use optimized QKV and output projections. The
|
It modifies the attention class to use optimized QKV and output projections. The
|
||||||
original implementation is preserved and can be restored if needed.
|
original implementation is preserved and can be restored if needed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
model: A HuggingFace transformers model.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AssertionError: If the required code blocks are not found in the attention
|
AssertionError: If the required code blocks are not found in the attention
|
||||||
implementation.
|
implementation.
|
||||||
"""
|
"""
|
||||||
attention_cls = get_attention_cls_from_config(cfg)
|
# Find all attention modules in the model
|
||||||
|
attention_modules = [
|
||||||
|
module
|
||||||
|
for module in model.modules()
|
||||||
|
if "attention" in module.__class__.__name__.lower()
|
||||||
|
and hasattr(module, "forward")
|
||||||
|
]
|
||||||
|
|
||||||
# Check if already patched
|
if not attention_modules:
|
||||||
if hasattr(attention_cls, "_original_forward"):
|
LOG.warning("No attention modules found in model")
|
||||||
LOG.info(f"{attention_cls.__name__} already patched")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
attention_classes = {type(module) for module in attention_modules}
|
||||||
|
LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}")
|
||||||
|
|
||||||
|
for attention_cls in attention_classes:
|
||||||
|
# Skip if already patched
|
||||||
|
if hasattr(attention_cls, "_original_forward"):
|
||||||
|
LOG.info(f"{attention_cls.__name__} already patched")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get and store original forward implementation
|
||||||
self_attn_forward = inspect.getsource(attention_cls.forward)
|
self_attn_forward = inspect.getsource(attention_cls.forward)
|
||||||
attention_cls._original_forward = self_attn_forward
|
attention_cls._original_forward = self_attn_forward
|
||||||
|
|
||||||
|
# Remove indentation
|
||||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||||
|
|
||||||
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
|
# Verify required code blocks exist
|
||||||
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
|
assert (
|
||||||
|
ORIGINAL_QKV_CODE in self_attn_forward
|
||||||
|
), f"Original QKV code not found in {attention_cls.__name__}"
|
||||||
|
assert (
|
||||||
|
ORIGINAL_O_CODE in self_attn_forward
|
||||||
|
), f"Original O code not found in {attention_cls.__name__}"
|
||||||
|
|
||||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
# Replace code blocks
|
||||||
|
self_attn_forward = self_attn_forward.replace(
|
||||||
|
ORIGINAL_QKV_CODE, PATCHED_QKV_CODE
|
||||||
|
)
|
||||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||||
self_attn_forward = self_attn_forward.replace(
|
self_attn_forward = self_attn_forward.replace(
|
||||||
"def forward(",
|
"def forward(",
|
||||||
@@ -179,7 +157,7 @@ def patch_self_attn_lora(cfg: DictDefault):
|
|||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load necessary imports
|
# Import necessary symbols from the attention module
|
||||||
module_name = attention_cls.__module__
|
module_name = attention_cls.__module__
|
||||||
module = importlib.import_module(module_name)
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
@@ -188,10 +166,13 @@ def patch_self_attn_lora(cfg: DictDefault):
|
|||||||
if item in self_attn_forward:
|
if item in self_attn_forward:
|
||||||
items_to_import.append(item)
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
if items_to_import:
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Execute the new implementation
|
||||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
|
||||||
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
||||||
|
|||||||
@@ -127,8 +127,6 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
**_kwargs,
|
**_kwargs,
|
||||||
):
|
):
|
||||||
if not optimizer:
|
|
||||||
optimizer = state.optimizer
|
|
||||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||||
checkpoint_folder = os.path.join(
|
checkpoint_folder = os.path.join(
|
||||||
args.output_dir,
|
args.output_dir,
|
||||||
|
|||||||
@@ -272,7 +272,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
dict(zip(feature_names, row))
|
dict(zip(feature_names, row))
|
||||||
)
|
)
|
||||||
for key, val in tokenized_prompt.items():
|
for key, val in tokenized_prompt.items():
|
||||||
res[key].append(val)
|
for i in range(0, len(val), self.sequence_len):
|
||||||
|
res[key].append(val[i : i + self.sequence_len])
|
||||||
|
|
||||||
# If there are no examples left, return an empty dictionary
|
# If there are no examples left, return an empty dictionary
|
||||||
if not res:
|
if not res:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union, Callable
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
@@ -53,7 +53,6 @@ class DataCollatorForSeq2Seq:
|
|||||||
label_pad_token_id: int = -100
|
label_pad_token_id: int = -100
|
||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
sp_extract_fn: Optional[Callable] = None
|
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
labels = None
|
labels = None
|
||||||
@@ -122,10 +121,6 @@ class DataCollatorForSeq2Seq:
|
|||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
def seq_parallel_split(self, features):
|
|
||||||
if self.sp_extract_fn:
|
|
||||||
pass
|
|
||||||
return features
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
|
|||||||
@@ -342,7 +342,6 @@ class LoraConfig(BaseModel):
|
|||||||
peft_use_dora: Optional[bool] = None
|
peft_use_dora: Optional[bool] = None
|
||||||
peft_use_rslora: Optional[bool] = None
|
peft_use_rslora: Optional[bool] = None
|
||||||
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
||||||
peft_init_lora_weights: Optional[Union[bool, str]] = None
|
|
||||||
|
|
||||||
qlora_sharded_model_loading: Optional[bool] = Field(
|
qlora_sharded_model_loading: Optional[bool] = Field(
|
||||||
default=False,
|
default=False,
|
||||||
@@ -832,8 +831,6 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
eager_attention: Optional[bool] = None
|
eager_attention: Optional[bool] = None
|
||||||
|
|
||||||
sp_ulysses_degree: Optional[int] = None
|
|
||||||
|
|
||||||
unsloth_cross_entropy_loss: Optional[bool] = None
|
unsloth_cross_entropy_loss: Optional[bool] = None
|
||||||
unsloth_lora_mlp: Optional[bool] = None
|
unsloth_lora_mlp: Optional[bool] = None
|
||||||
unsloth_lora_qkv: Optional[bool] = None
|
unsloth_lora_qkv: Optional[bool] = None
|
||||||
|
|||||||
@@ -172,11 +172,10 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
min_input_len = np.min(get_dataset_lengths(dataset))
|
||||||
min_input_len = np.min(ds_lengths)
|
LOG.debug(f"min_input_len: {min_input_len}")
|
||||||
LOG.info(f"min_input_len: {min_input_len}")
|
max_input_len = np.max(get_dataset_lengths(dataset))
|
||||||
max_input_len = np.max(ds_lengths)
|
LOG.debug(f"max_input_len: {max_input_len}")
|
||||||
LOG.info(f"max_input_len: {max_input_len}")
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -86,12 +86,6 @@ def get_world_size():
|
|||||||
return int(os.getenv("WORLD_SIZE", "1"))
|
return int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
|
||||||
|
|
||||||
def get_rank():
|
|
||||||
if not is_distributed():
|
|
||||||
return 0
|
|
||||||
return dist.get_rank()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def zero_only():
|
def zero_only():
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -439,11 +439,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_mistral_cross_entropy()
|
patch_mistral_cross_entropy()
|
||||||
|
|
||||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
|
||||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
|
||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||||
@@ -1028,6 +1023,12 @@ class ModelLoader:
|
|||||||
integrate_rope_embeddings()
|
integrate_rope_embeddings()
|
||||||
|
|
||||||
def apply_lora_patch(self) -> None:
|
def apply_lora_patch(self) -> None:
|
||||||
|
"""Applies patching relevant to LoRA Triton kernels if enabled."""
|
||||||
|
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
||||||
|
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||||
|
|
||||||
|
patch_self_attn_lora(self.model)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.cfg.lora_mlp_kernel
|
self.cfg.lora_mlp_kernel
|
||||||
or self.cfg.lora_qkv_kernel
|
or self.cfg.lora_qkv_kernel
|
||||||
@@ -1181,6 +1182,7 @@ class ModelLoader:
|
|||||||
if self.cfg.adapter is not None:
|
if self.cfg.adapter is not None:
|
||||||
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
||||||
|
|
||||||
|
# TODO: Deprecate this.
|
||||||
self.apply_unsloth_lora_patch()
|
self.apply_unsloth_lora_patch()
|
||||||
self.apply_lora_patch()
|
self.apply_lora_patch()
|
||||||
|
|
||||||
@@ -1201,9 +1203,7 @@ def load_model(
|
|||||||
reference_model: bool = False,
|
reference_model: bool = False,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||||
"""
|
"""Load a model for a given configuration and tokenizer."""
|
||||||
Load a model for a given configuration and tokenizer.
|
|
||||||
"""
|
|
||||||
loader = ModelLoader(
|
loader = ModelLoader(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -1321,8 +1321,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
if loftq_bits:
|
if loftq_bits:
|
||||||
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
|
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
|
||||||
lora_config_kwargs["init_lora_weights"] = "loftq"
|
lora_config_kwargs["init_lora_weights"] = "loftq"
|
||||||
if cfg.peft_init_lora_weights:
|
|
||||||
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
|
|
||||||
if cfg.peft_use_dora:
|
if cfg.peft_use_dora:
|
||||||
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
||||||
LOG.info("Initializing LoRA weights using dora. This might take longer.")
|
LOG.info("Initializing LoRA weights using dora. This might take longer.")
|
||||||
|
|||||||
@@ -4,17 +4,13 @@ helper util to calculate dataset lengths
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_lengths(dataset, from_arrow=False):
|
def get_dataset_lengths(dataset):
|
||||||
if "length" in dataset.column_names:
|
if "length" in dataset.data.column_names:
|
||||||
lengths = np.array(dataset["length"])
|
lengths = np.array(dataset.data.column("length"))
|
||||||
elif "position_ids" in dataset.column_names:
|
elif "position_ids" in dataset.data.column_names:
|
||||||
position_ids = dataset["position_ids"]
|
position_ids = dataset.data.column("position_ids")
|
||||||
lengths = np.array([x[-1] + 1 for x in position_ids])
|
lengths = np.array([x[-1] + 1 for x in position_ids])
|
||||||
else:
|
else:
|
||||||
if from_arrow:
|
|
||||||
input_ids = dataset.data.column("input_ids")
|
input_ids = dataset.data.column("input_ids")
|
||||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||||
else:
|
|
||||||
input_ids = dataset["input_ids"]
|
|
||||||
lengths = np.array([len(seq) for seq in input_ids])
|
|
||||||
return lengths
|
return lengths
|
||||||
|
|||||||
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
elif cfg.sample_packing or cfg.sp_ulysses_degree:
|
elif cfg.sample_packing:
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||||
|
|||||||
@@ -9,16 +9,14 @@ from transformers import AutoModelForCausalLM, LlamaForCausalLM
|
|||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||||
|
|
||||||
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.kernels.lora import (
|
from axolotl.kernels.lora import (
|
||||||
apply_lora_mlp_geglu,
|
apply_lora_mlp_geglu,
|
||||||
apply_lora_mlp_swiglu,
|
apply_lora_mlp_swiglu,
|
||||||
apply_lora_o,
|
apply_lora_o,
|
||||||
apply_lora_qkv,
|
apply_lora_qkv,
|
||||||
)
|
)
|
||||||
from axolotl.monkeypatch.lora_kernels import (
|
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
|
||||||
apply_lora_kernel_patches,
|
|
||||||
patch_self_attn_lora,
|
|
||||||
)
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
MODEL_CONFIGS = [
|
MODEL_CONFIGS = [
|
||||||
@@ -65,15 +63,45 @@ def small_llama_model():
|
|||||||
return LlamaForCausalLM(LlamaConfig(**config))
|
return LlamaForCausalLM(LlamaConfig(**config))
|
||||||
|
|
||||||
|
|
||||||
def test_attention_patching_integration():
|
# pylint: disable=duplicate-code
|
||||||
"""Test attention patching in integration context."""
|
@pytest.fixture
|
||||||
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
|
def minimal_cfg():
|
||||||
|
"Config of real HuggingFace Hub model"
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"learning_rate": 0.000001,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.0,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"lora_mlp_kernel": True,
|
||||||
|
"lora_qkv_kernel": True,
|
||||||
|
"lora_o_kernel": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def test_attention_patching_integration(minimal_cfg):
|
||||||
|
"""Test attention patching in integration context."""
|
||||||
# Store the original implementation
|
# Store the original implementation
|
||||||
original_forward = getattr(LlamaAttention, "forward")
|
original_forward = getattr(LlamaAttention, "forward")
|
||||||
|
|
||||||
# Apply patch
|
# Load model
|
||||||
patch_self_attn_lora(cfg)
|
_, _ = load_model_and_tokenizer(cfg=minimal_cfg)
|
||||||
|
|
||||||
# Get the new forward method
|
# Get the new forward method
|
||||||
patched_forward = LlamaAttention.forward
|
patched_forward = LlamaAttention.forward
|
||||||
@@ -376,38 +404,10 @@ def test_model_architecture(model_config):
|
|||||||
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
def test_kernel_training_integration():
|
def test_kernel_training_integration(minimal_cfg):
|
||||||
"""Test model loading with kernel patches enabled."""
|
"""Test model loading with kernel patches enabled."""
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
|
||||||
|
|
||||||
# Create minimal config
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"learning_rate": 0.000001,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"lora_mlp_kernel": True,
|
|
||||||
"lora_qkv_kernel": True,
|
|
||||||
"lora_o_kernel": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model, _ = load_model_and_tokenizer(cfg=cfg)
|
model, _ = load_model_and_tokenizer(cfg=minimal_cfg)
|
||||||
|
|
||||||
# Verify correct activation function
|
# Verify correct activation function
|
||||||
layer = model.model.model.layers[0]
|
layer = model.model.model.layers[0]
|
||||||
|
|||||||
@@ -125,12 +125,6 @@ def fixture_llama3_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
|
|
||||||
def fixture_smollm2_tokenizer():
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
||||||
def fixture_mistralv03_tokenizer():
|
def fixture_mistralv03_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for loading DPO preference datasets with chatml formatting
|
|
||||||
"""
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
|
||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="minimal_dpo_cfg")
|
|
||||||
def fixture_cfg():
|
|
||||||
return DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"rl": "dpo",
|
|
||||||
"learning_rate": 0.000001,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"sequence_len": 2048,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDPOChatml:
|
|
||||||
"""
|
|
||||||
Test loading DPO preference datasets with chatml formatting
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_default(self, minimal_dpo_cfg):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "argilla/distilabel-intel-orca-dpo-pairs",
|
|
||||||
"type": "chatml",
|
|
||||||
"split": "train[:1%]",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
| minimal_dpo_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
# test that dpo.load works
|
|
||||||
load_dpo("chatml", cfg)
|
|
||||||
# now actually load the datasets with the strategy
|
|
||||||
train_ds, _ = load_prepare_preference_datasets(cfg)
|
|
||||||
assert train_ds[0]["prompt"].startswith("<|im_start|>")
|
|
||||||
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
|
|
||||||
assert "chosen" in train_ds[0]
|
|
||||||
assert "rejected" in train_ds[0]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -7,7 +7,6 @@ from transformers import AutoTokenizer
|
|||||||
from axolotl.datasets import TokenizedPromptDataset
|
from axolotl.datasets import TokenizedPromptDataset
|
||||||
from axolotl.prompt_strategies.completion import load
|
from axolotl.prompt_strategies.completion import load
|
||||||
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
from axolotl.utils.data.utils import drop_long_seq_in_dataset
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
@@ -19,6 +18,11 @@ def fixture_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="max_seq_length")
|
||||||
|
def fixture_max_seq_length():
|
||||||
|
return 4096
|
||||||
|
|
||||||
|
|
||||||
class TestBatchedSamplerPacking:
|
class TestBatchedSamplerPacking:
|
||||||
"""
|
"""
|
||||||
Test class for packing streaming dataset sequences
|
Test class for packing streaming dataset sequences
|
||||||
@@ -33,7 +37,6 @@ class TestBatchedSamplerPacking:
|
|||||||
(2, 2),
|
(2, 2),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("max_seq_length", [4096, 512])
|
|
||||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
||||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||||
|
|
||||||
@@ -59,9 +62,6 @@ class TestBatchedSamplerPacking:
|
|||||||
dataset,
|
dataset,
|
||||||
)
|
)
|
||||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||||
|
|
||||||
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
|
|
||||||
|
|
||||||
lengths = get_dataset_lengths(train_dataset)
|
lengths = get_dataset_lengths(train_dataset)
|
||||||
batch_sampler = MultipackBatchSampler(
|
batch_sampler = MultipackBatchSampler(
|
||||||
sampler=RandomSampler(train_dataset),
|
sampler=RandomSampler(train_dataset),
|
||||||
@@ -90,7 +90,7 @@ class TestBatchedSamplerPacking:
|
|||||||
batch_idxs.extend(pack)
|
batch_idxs.extend(pack)
|
||||||
|
|
||||||
for batch in loader:
|
for batch in loader:
|
||||||
assert batch["input_ids"].numel() <= batch_size * max_seq_length
|
assert len(batch["input_ids"]) <= batch_size * max_seq_length
|
||||||
assert batch["input_ids"].shape[1] == max_seq_length
|
assert batch["input_ids"].shape[1] == max_seq_length
|
||||||
|
|
||||||
original_idxs = set(range(len(train_dataset)))
|
original_idxs = set(range(len(train_dataset)))
|
||||||
|
|||||||
Reference in New Issue
Block a user