sampler / dataloader refactor
This commit is contained in:
@@ -7,7 +7,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
@@ -435,20 +435,25 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _create_sp_sampler(self, dataset, shuffle=True, is_eval=False):
|
||||||
|
"""Create a sampler for sequence parallelism"""
|
||||||
|
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||||
|
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
return torch.utils.data.distributed.DistributedSampler(
|
||||||
|
dataset,
|
||||||
|
num_replicas=num_sp_groups,
|
||||||
|
rank=sp_group_id,
|
||||||
|
seed=self.args.seed if shuffle else None,
|
||||||
|
shuffle=shuffle,
|
||||||
|
drop_last=not is_eval,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_train_sampler(self) -> torch.utils.data.Sampler | None:
|
||||||
# Handle sequence parallelism
|
# Handle sequence parallelism
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
base_sampler = self._create_sp_sampler(
|
||||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
self.train_dataset, shuffle=not self.args.curriculum_sampling
|
||||||
|
|
||||||
# Create base sampler for SP groups
|
|
||||||
base_sampler = torch.utils.data.distributed.DistributedSampler(
|
|
||||||
self.train_dataset,
|
|
||||||
num_replicas=num_sp_groups,
|
|
||||||
rank=sp_group_id,
|
|
||||||
seed=self.args.seed if not self.args.curriculum_sampling else None,
|
|
||||||
shuffle=not self.args.curriculum_sampling,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply multipack wrapper if needed
|
# Apply multipack wrapper if needed
|
||||||
@@ -458,16 +463,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
dataset=self.train_dataset,
|
dataset=self.train_dataset,
|
||||||
group_size=self.args.sample_packing_group_size,
|
group_size=self.args.sample_packing_group_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
return base_sampler
|
return base_sampler
|
||||||
|
|
||||||
|
# Regular training sampler logic
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
base_sampler = (
|
base_sampler = (
|
||||||
SequentialSampler(self.train_dataset)
|
SequentialSampler(self.train_dataset)
|
||||||
if self.args.curriculum_sampling
|
if self.args.curriculum_sampling
|
||||||
else RandomSampler(self.train_dataset)
|
else RandomSampler(self.train_dataset)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._create_multipack_sampler(
|
return self._create_multipack_sampler(
|
||||||
base_sampler=base_sampler,
|
base_sampler=base_sampler,
|
||||||
dataset=self.train_dataset,
|
dataset=self.train_dataset,
|
||||||
@@ -480,94 +484,76 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
def _get_eval_sampler(
|
def _get_eval_sampler(
|
||||||
self, eval_dataset: Optional[Dataset] = None
|
self, eval_dataset: Dataset | None = None
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
) -> torch.utils.data.Sampler | None:
|
||||||
"""Get evaluation sampler"""
|
"""Get evaluation sampler"""
|
||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
# Handle sequence parallelism
|
# Get the appropriate group size for sample packing
|
||||||
if self.args.sequence_parallel_degree > 1:
|
def get_pack_group_size():
|
||||||
# Create sampler for SP groups
|
return (
|
||||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
|
||||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
|
||||||
|
|
||||||
# Create distributed sampler for the SP group
|
|
||||||
base_sampler = torch.utils.data.distributed.DistributedSampler(
|
|
||||||
eval_dataset,
|
|
||||||
num_replicas=num_sp_groups,
|
|
||||||
rank=sp_group_id,
|
|
||||||
shuffle=False,
|
|
||||||
drop_last=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
|
||||||
group_size = (
|
|
||||||
self.args.eval_packing_group_size
|
|
||||||
if hasattr(self.args, "eval_packing_group_size")
|
|
||||||
else self.args.sample_packing_group_size
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._create_multipack_sampler(
|
|
||||||
base_sampler=base_sampler,
|
|
||||||
dataset=eval_dataset,
|
|
||||||
group_size=group_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
return base_sampler
|
|
||||||
|
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
|
||||||
base_sampler = SequentialSampler(eval_dataset)
|
|
||||||
group_size = (
|
|
||||||
self.args.eval_packing_group_size
|
self.args.eval_packing_group_size
|
||||||
if hasattr(self.args, "eval_packing_group_size")
|
if hasattr(self.args, "eval_packing_group_size")
|
||||||
else self.args.sample_packing_group_size
|
else self.args.sample_packing_group_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Handle sequence parallelism
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
base_sampler = self._create_sp_sampler(
|
||||||
|
eval_dataset, shuffle=False, is_eval=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
|
return self._create_multipack_sampler(
|
||||||
|
base_sampler=base_sampler,
|
||||||
|
dataset=eval_dataset,
|
||||||
|
group_size=get_pack_group_size(),
|
||||||
|
)
|
||||||
|
return base_sampler
|
||||||
|
|
||||||
|
# Regular evaluation sampler logic
|
||||||
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
|
base_sampler = SequentialSampler(eval_dataset)
|
||||||
return self._create_multipack_sampler(
|
return self._create_multipack_sampler(
|
||||||
base_sampler=base_sampler,
|
base_sampler=base_sampler,
|
||||||
dataset=eval_dataset,
|
dataset=eval_dataset,
|
||||||
group_size=group_size,
|
group_size=get_pack_group_size(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
|
||||||
"""Get dataloader for training"""
|
"""Create common dataloader parameters for train or eval."""
|
||||||
train_dataset = self.train_dataset
|
batch_size = custom_batch_size or (
|
||||||
data_collator = self.data_collator
|
self.args.eval_batch_size if is_eval else self._train_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
# Handle dataset preprocessing
|
params = {
|
||||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
"batch_size": batch_size,
|
||||||
if (
|
"collate_fn": self.data_collator,
|
||||||
self.args.sample_packing
|
|
||||||
and not self.args.pretraining
|
|
||||||
and "length" in train_dataset.features
|
|
||||||
):
|
|
||||||
train_dataset = train_dataset.remove_columns(["length"])
|
|
||||||
if not (self.args.sample_packing and not self.args.pretraining):
|
|
||||||
train_dataset = self._remove_unused_columns(
|
|
||||||
train_dataset, description="training"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
data_collator = self._get_collator_with_removed_columns(
|
|
||||||
data_collator, description="training"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build common dataloader parameters
|
|
||||||
dataloader_params = {
|
|
||||||
"batch_size": self._train_batch_size,
|
|
||||||
"collate_fn": data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
"persistent_workers": self.args.dataloader_persistent_workers,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Add persistent workers only for training
|
||||||
|
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
|
||||||
|
params["persistent_workers"] = self.args.dataloader_persistent_workers
|
||||||
|
|
||||||
|
# Add prefetch factor if specified
|
||||||
if self.args.dataloader_prefetch_factor:
|
if self.args.dataloader_prefetch_factor:
|
||||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||||
|
|
||||||
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
return params
|
||||||
sampler = self._get_train_sampler()
|
|
||||||
|
|
||||||
|
def _prepare_dataloader(
|
||||||
|
self, dataset, sampler, is_eval=False, custom_batch_size=None
|
||||||
|
):
|
||||||
|
"""Prepare a dataloader with the given dataset and sampler."""
|
||||||
|
# Get base parameters
|
||||||
|
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
|
||||||
|
|
||||||
|
# Add sampler configuration
|
||||||
|
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
# batch_size and batch_sampler are mutually exclusive
|
# batch_size and batch_sampler are mutually exclusive
|
||||||
dataloader_params["batch_sampler"] = sampler
|
dataloader_params["batch_sampler"] = sampler
|
||||||
@@ -576,114 +562,110 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
dataloader_params["sampler"] = sampler
|
dataloader_params["sampler"] = sampler
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
if not is_eval:
|
||||||
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
dataloader = DataLoader(train_dataset, **dataloader_params)
|
# Create the dataloader
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
dataloader = DataLoader(dataset, **dataloader_params)
|
||||||
|
|
||||||
|
if self.args.sample_packing and (
|
||||||
|
(not is_eval and not self.args.pretraining)
|
||||||
|
or (is_eval and self.args.eval_sample_packing is not False)
|
||||||
|
):
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
|
|
||||||
# Don't prepare dataloader for sequence parallelism
|
# Return unprepared dataloader if using sequence parallelism
|
||||||
# We use a distributed sampler in this case
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
|
# Otherwise prepare with accelerator
|
||||||
return self.accelerator.prepare_data_loader(dataloader)
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
|
"""Get dataloader for training"""
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
data_collator = self.data_collator # type: ignore
|
||||||
|
|
||||||
|
# Handle dataset preprocessing
|
||||||
|
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||||
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
|
train_dataset = train_dataset.remove_columns(["length"])
|
||||||
|
if not (self.args.sample_packing and not self.args.pretraining):
|
||||||
|
train_dataset = self._remove_unused_columns(
|
||||||
|
train_dataset, description="training"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||||
|
data_collator,
|
||||||
|
description="training",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get sampler and create dataloader
|
||||||
|
sampler = self._get_train_sampler()
|
||||||
|
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
|
||||||
|
|
||||||
|
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
||||||
"""Get dataloader for evaluation"""
|
"""Get dataloader for evaluation"""
|
||||||
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
|
# Handle special case: sample packing is enabled but eval_sample_packing is False
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.eval_data_collator
|
self.eval_data_collator
|
||||||
)
|
)
|
||||||
if eval_dataset and "length" in eval_dataset.features:
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
|
||||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.train_data_collator
|
self.train_data_collator
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
|
||||||
eval_dataset = (
|
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
|
||||||
)
|
|
||||||
|
|
||||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
# Handle sample packing or sequence parallelism
|
||||||
|
if (
|
||||||
# Only remove length column if it exists
|
self.args.sample_packing
|
||||||
if "length" in eval_dataset.features:
|
and self.args.eval_sample_packing is not False
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
or self.args.sequence_parallel_degree > 1
|
||||||
|
):
|
||||||
data_collator = self.data_collator
|
# Get appropriate data collator
|
||||||
dataloader_params = {
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
"batch_size": self.args.eval_batch_size,
|
|
||||||
"collate_fn": data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
|
||||||
dataloader_params["prefetch_factor"] = (
|
|
||||||
self.args.dataloader_prefetch_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(eval_sampler, BatchSampler):
|
|
||||||
dataloader_params["batch_sampler"] = eval_sampler
|
|
||||||
del dataloader_params["batch_size"]
|
|
||||||
else:
|
|
||||||
dataloader_params["sampler"] = eval_sampler
|
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
|
||||||
|
|
||||||
self.accelerator.even_batches = False
|
|
||||||
dataloader = DataLoader(eval_dataset, **dataloader_params)
|
|
||||||
|
|
||||||
# Don't prepare dataloader for sequence parallelism
|
|
||||||
# We use a distributed sampler in this case
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
return self.accelerator.prepare_data_loader(dataloader)
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
eval_dataset = (
|
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
|
||||||
)
|
|
||||||
data_collator = (
|
|
||||||
self.eval_data_collator
|
self.eval_data_collator
|
||||||
if self.eval_data_collator
|
if hasattr(self, "eval_data_collator") and self.eval_data_collator
|
||||||
else self.data_collator
|
else self.data_collator
|
||||||
)
|
)
|
||||||
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
|
|
||||||
# Handle dataset preprocessing as in the parent implementation
|
# Handle dataset preprocessing for SP
|
||||||
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
if self.args.sequence_parallel_degree > 1:
|
||||||
eval_dataset = self._remove_unused_columns(
|
if is_datasets_available() and isinstance(
|
||||||
eval_dataset, description="evaluation"
|
eval_dataset, datasets.Dataset
|
||||||
)
|
):
|
||||||
else:
|
eval_dataset = self._remove_unused_columns(
|
||||||
data_collator = self._get_collator_with_removed_columns(
|
eval_dataset, description="evaluation"
|
||||||
data_collator, description="evaluation"
|
)
|
||||||
)
|
else:
|
||||||
|
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.data_collator, description="evaluation"
|
||||||
|
)
|
||||||
|
|
||||||
# Build dataloader parameters
|
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
|
||||||
dataloader_params = {
|
batch_size = (
|
||||||
"batch_size": self.args.per_device_eval_batch_size,
|
self.args.eval_batch_size
|
||||||
"collate_fn": data_collator,
|
if self.args.sample_packing
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
else self.args.per_device_eval_batch_size
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
)
|
||||||
}
|
sampler = self._get_eval_sampler(eval_dataset)
|
||||||
|
dataloader = self._prepare_dataloader(
|
||||||
|
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
return dataloader
|
||||||
sampler = self._get_eval_sampler(eval_dataset)
|
|
||||||
dataloader_params["sampler"] = sampler
|
|
||||||
|
|
||||||
# Don't prepare dataloader for sequence parallelism
|
|
||||||
# We use a distributed sampler in this case
|
|
||||||
return DataLoader(eval_dataset, **dataloader_params)
|
|
||||||
|
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
def _get_bench_sampler(
|
def _get_bench_sampler(
|
||||||
self, bench_dataset: Dataset
|
self, bench_dataset: Dataset
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
) -> torch.utils.data.Sampler | None:
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
return SequentialSampler(bench_dataset)
|
return SequentialSampler(bench_dataset)
|
||||||
return None
|
return None
|
||||||
@@ -920,15 +902,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
Log `logs` on the various objects watching training, including stored metrics.
|
Log `logs` on the various objects watching training, including stored metrics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logs (`Dict[str, float]`):
|
logs: The values to log.
|
||||||
The values to log.
|
start_time: The start of training.
|
||||||
start_time (`Optional[float]`):
|
|
||||||
The start of training.
|
|
||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
@@ -940,7 +920,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return super().log(logs, start_time)
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||||
) -> None:
|
) -> None:
|
||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
self._stored_metrics[train_eval][key].append(value)
|
||||||
@@ -1043,7 +1023,7 @@ class ReLoRATrainer(AxolotlTrainer):
|
|||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
self,
|
self,
|
||||||
num_training_steps: int,
|
num_training_steps: int,
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
optimizer: torch.optim.Optimizer | None = None,
|
||||||
):
|
):
|
||||||
optimizer = self.optimizer if optimizer is None else optimizer
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
|||||||
Reference in New Issue
Block a user