add back in reinit_weights (clobbered?); masking / pretrain fixes
This commit is contained in:
@@ -34,7 +34,7 @@ lr_scheduler: cosine
|
|||||||
learning_rate: 3e-4
|
learning_rate: 3e-4
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
tf32: false
|
tf32: true
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
@@ -51,8 +51,8 @@ eval_steps: 1000
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
wandb_project: diffusion-plugin
|
wandb_project:
|
||||||
wandb_entity: axolotl-ai
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -82,7 +82,9 @@ class AxolotlTrainer(
|
|||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
|
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(
|
||||||
|
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
|
||||||
|
)
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
@@ -573,9 +575,26 @@ class AxolotlTrainer(
|
|||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
# Add averaged stored metrics to logs
|
|
||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
# Add reduced stored metrics to logs
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
for key, metric_data in self._stored_metrics[train_eval].items():
|
||||||
|
values = torch.tensor(metric_data["values"])
|
||||||
|
reduction_type = metric_data["reduction"]
|
||||||
|
|
||||||
|
if reduction_type == "mean":
|
||||||
|
logs[key] = values.mean().item()
|
||||||
|
elif reduction_type == "min":
|
||||||
|
logs[key] = values.min().item()
|
||||||
|
elif reduction_type == "max":
|
||||||
|
logs[key] = values.max().item()
|
||||||
|
elif reduction_type == "sum":
|
||||||
|
logs[key] = values.sum().item()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Metric reduction must be one of [mean, min, max]"
|
||||||
|
)
|
||||||
|
|
||||||
|
logs[key] = round(logs[key], 4)
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
# Add memory usage
|
# Add memory usage
|
||||||
@@ -592,10 +611,27 @@ class AxolotlTrainer(
|
|||||||
return super().log(logs, start_time)
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self,
|
||||||
|
metrics: dict[str, float] | dict[str, tuple[int | float, str]],
|
||||||
|
train_eval: Literal["train", "eval"] = "train",
|
||||||
|
reduction: Literal["mean", "min", "max", "sum"] = "mean",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Store metrics with specified reduction type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metrics: Dictionary of metric names to values, or metric names to (value,
|
||||||
|
reduction_type) tuples.
|
||||||
|
train_eval: Whether this is for training or evaluation.
|
||||||
|
"""
|
||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
if isinstance(value, tuple):
|
||||||
|
metric_value, metric_reduction = value
|
||||||
|
else:
|
||||||
|
metric_value, metric_reduction = value, reduction
|
||||||
|
|
||||||
|
self._stored_metrics[train_eval][key]["values"].append(metric_value)
|
||||||
|
self._stored_metrics[train_eval][key]["reduction"] = metric_reduction
|
||||||
|
|
||||||
def _save_checkpoint(self, model, trial, **kwargs):
|
def _save_checkpoint(self, model, trial, **kwargs):
|
||||||
# make sure the checkpoint dir exists, since trainer is flakey
|
# make sure the checkpoint dir exists, since trainer is flakey
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Custom trainer for diffusion LM training."""
|
"""Custom trainer for diffusion LM training."""
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -16,14 +18,35 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.config = None
|
self._config = None
|
||||||
self._special_token_ids = None
|
self._special_token_ids = None
|
||||||
|
|
||||||
def set_config(self, config: DictDefault):
|
def set_config(self, config: DictDefault):
|
||||||
"""Set config for diffusion training."""
|
"""Set config for diffusion training."""
|
||||||
self.config = config
|
self._config = config
|
||||||
self._cache_special_token_ids()
|
self._cache_special_token_ids()
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: dict[str, torch.Tensor],
|
||||||
|
return_outputs: bool = False,
|
||||||
|
num_items_in_batch: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||||
|
"""Override compute_loss to use diffusion loss."""
|
||||||
|
input_ids = inputs.get("input_ids")
|
||||||
|
attention_mask = inputs.get("attention_mask")
|
||||||
|
|
||||||
|
if input_ids is None:
|
||||||
|
raise ValueError("input_ids is required for diffusion training")
|
||||||
|
|
||||||
|
loss, outputs = self._compute_diffusion_loss(model, input_ids, attention_mask)
|
||||||
|
|
||||||
|
if return_outputs:
|
||||||
|
return loss, outputs
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
def _cache_special_token_ids(self):
|
def _cache_special_token_ids(self):
|
||||||
"""Cache special token IDs to avoid repeated tokenizer access."""
|
"""Cache special token IDs to avoid repeated tokenizer access."""
|
||||||
if self.processing_class is None:
|
if self.processing_class is None:
|
||||||
@@ -42,7 +65,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
|
|
||||||
self._special_token_ids = special_tokens
|
self._special_token_ids = special_tokens
|
||||||
|
|
||||||
def forward_process(
|
def _forward_process(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
attention_mask: torch.Tensor | None = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
@@ -90,14 +113,14 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
masked_indices = masked_indices & attention_mask.bool()
|
masked_indices = masked_indices & attention_mask.bool()
|
||||||
|
|
||||||
# Get mask token ID from config
|
# Get mask token ID from config
|
||||||
mask_token_id = self.config.mask_token_id
|
mask_token_id = self._config.mask_token_id
|
||||||
|
|
||||||
# Create masked input using configured mask token
|
# Create masked input using configured mask token
|
||||||
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
|
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
|
||||||
|
|
||||||
return noisy_batch, masked_indices, p_mask
|
return noisy_batch, masked_indices, p_mask
|
||||||
|
|
||||||
def create_bidirectional_attention_mask(
|
def _create_bidirectional_attention_mask(
|
||||||
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
|
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -115,7 +138,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
batch_size, seq_len = input_ids.shape
|
batch_size, seq_len = input_ids.shape
|
||||||
device = input_ids.device
|
device = input_ids.device
|
||||||
|
|
||||||
if attention_mask is None or not self.config.sample_packing:
|
if attention_mask is None or not self._config.sample_packing:
|
||||||
# Simple case: no attention mask, allow all-to-all attention
|
# Simple case: no attention mask, allow all-to-all attention
|
||||||
return torch.ones(
|
return torch.ones(
|
||||||
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
|
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
|
||||||
@@ -133,12 +156,12 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
|
|
||||||
return bidirectional_mask
|
return bidirectional_mask
|
||||||
|
|
||||||
def compute_diffusion_loss(
|
def _compute_diffusion_loss(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
attention_mask: torch.Tensor | None = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, dict[str, float]]:
|
) -> tuple[torch.Tensor, torch.Tensor | Any]:
|
||||||
"""
|
"""
|
||||||
Compute diffusion loss.
|
Compute diffusion loss.
|
||||||
|
|
||||||
@@ -152,12 +175,12 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
metrics: Dictionary of metrics.
|
metrics: Dictionary of metrics.
|
||||||
"""
|
"""
|
||||||
# Apply forward process
|
# Apply forward process
|
||||||
noisy_batch, masked_indices, p_mask = self.forward_process(
|
noisy_batch, masked_indices, p_mask = self._forward_process(
|
||||||
input_ids, attention_mask, self.config.eps
|
input_ids, attention_mask, self._config.eps
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create bidirectional attention mask
|
# Create bidirectional attention mask
|
||||||
bidirectional_mask = self.create_bidirectional_attention_mask(
|
bidirectional_mask = self._create_bidirectional_attention_mask(
|
||||||
input_ids, attention_mask
|
input_ids, attention_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -187,7 +210,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Apply importance weighting if enabled
|
# Apply importance weighting if enabled
|
||||||
if self.config.importance_weighting:
|
if self._config.importance_weighting:
|
||||||
masked_p_mask = masked_p_mask.float()
|
masked_p_mask = masked_p_mask.float()
|
||||||
weighted_loss = token_loss / masked_p_mask
|
weighted_loss = token_loss / masked_p_mask
|
||||||
else:
|
else:
|
||||||
@@ -211,40 +234,15 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
"accuracy": accuracy.item(),
|
"accuracy": accuracy.item(),
|
||||||
"mask_ratio": masked_indices.float().mean().item(),
|
"mask_ratio": masked_indices.float().mean().item(),
|
||||||
"num_masked_tokens": masked_indices.sum().item(),
|
"num_masked_tokens": (masked_indices.sum().item(), "sum"),
|
||||||
"avg_p_mask": p_mask[masked_indices].mean().item(),
|
"avg_p_mask": p_mask[masked_indices].mean().item(),
|
||||||
"ce_loss": ce_loss.item(),
|
"ce_loss": ce_loss.item(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.config.importance_weighting:
|
if self._config.importance_weighting:
|
||||||
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
|
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
|
||||||
|
|
||||||
return loss, metrics
|
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
|
||||||
|
self.store_metrics(metrics, train_eval=train_eval)
|
||||||
|
|
||||||
def compute_loss(
|
return loss, outputs
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
inputs: dict[str, torch.Tensor],
|
|
||||||
return_outputs: bool = False,
|
|
||||||
num_items_in_batch: torch.Tensor | None = None,
|
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
|
||||||
"""Override compute_loss to use diffusion loss."""
|
|
||||||
input_ids = inputs.get("input_ids")
|
|
||||||
attention_mask = inputs.get("attention_mask")
|
|
||||||
|
|
||||||
if input_ids is None:
|
|
||||||
raise ValueError("input_ids is required for diffusion training")
|
|
||||||
|
|
||||||
loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)
|
|
||||||
|
|
||||||
# # Log metrics
|
|
||||||
# if self.state.is_local_process_zero:
|
|
||||||
# for key, value in metrics.items():
|
|
||||||
# self.log({f"train/diffusion_{key}": value})
|
|
||||||
|
|
||||||
if return_outputs:
|
|
||||||
# TODO: compute outputs (?)
|
|
||||||
outputs = [loss]
|
|
||||||
return (loss, outputs)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|||||||
@@ -667,6 +667,23 @@ class ModelLoader:
|
|||||||
|
|
||||||
return hf_ds_cfg
|
return hf_ds_cfg
|
||||||
|
|
||||||
|
def _load_model_from_config(self) -> PreTrainedModel:
|
||||||
|
"""Load model with random initialization using from_config."""
|
||||||
|
if self.auto_model_loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
|
||||||
|
return self.auto_model_loader.from_config(config=self.model_config)
|
||||||
|
return self.auto_model_loader(config=self.model_config)
|
||||||
|
|
||||||
|
def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:
|
||||||
|
"""Load model from pretrained weights."""
|
||||||
|
loader = model_loader_class or self.auto_model_loader
|
||||||
|
kwargs = {
|
||||||
|
**self.model_kwargs,
|
||||||
|
"config": self.model_config,
|
||||||
|
"trust_remote_code": self.cfg.trust_remote_code or False,
|
||||||
|
**self.model_kwargs,
|
||||||
|
}
|
||||||
|
return loader.from_pretrained(self.base_model, **kwargs)
|
||||||
|
|
||||||
def _build_model(self) -> bool:
|
def _build_model(self) -> bool:
|
||||||
"""Load model, with load strategy depending on config."""
|
"""Load model, with load strategy depending on config."""
|
||||||
skip_move_to_device = False
|
skip_move_to_device = False
|
||||||
@@ -681,7 +698,8 @@ class ModelLoader:
|
|||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
# Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
|
# Don't delete device_map for QLoRA + FSDP - it was set correctly in
|
||||||
|
# _set_device_map
|
||||||
if (
|
if (
|
||||||
"device_map" in self.model_kwargs
|
"device_map" in self.model_kwargs
|
||||||
and not self.is_qlora_and_fsdp_enabled
|
and not self.is_qlora_and_fsdp_enabled
|
||||||
@@ -710,6 +728,11 @@ class ModelLoader:
|
|||||||
or self.cfg.qlora_sharded_model_loading
|
or self.cfg.qlora_sharded_model_loading
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
|
if self.cfg.reinit_weights:
|
||||||
|
LOG.warning(
|
||||||
|
"reinit_weights is not supported with sharded quantized loading. "
|
||||||
|
"Loading from pretrained weights instead."
|
||||||
|
)
|
||||||
quant_storage = self.cfg.torch_dtype
|
quant_storage = self.cfg.torch_dtype
|
||||||
quantization_config = getattr(
|
quantization_config = getattr(
|
||||||
self.model_config, "quantization_config", None
|
self.model_config, "quantization_config", None
|
||||||
@@ -725,33 +748,12 @@ class ModelLoader:
|
|||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
)
|
)
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
elif (
|
|
||||||
self.model_config.model_type in ["llama", "llama4"]
|
|
||||||
and not self.cfg.trust_remote_code
|
|
||||||
and not self.cfg.gptq
|
|
||||||
):
|
|
||||||
# Please don't remove underscore binding without reading the fn docstring.
|
|
||||||
_ = self._configure_zero3_memory_efficient_loading()
|
|
||||||
|
|
||||||
# Load model with random initialization if specified
|
|
||||||
if self.cfg.random_init_weights:
|
|
||||||
# AutoModel classes support the from_config method
|
|
||||||
if self.auto_model_loader in [
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
AutoModelForVision2Seq,
|
|
||||||
]:
|
|
||||||
self.model = self.auto_model_loader.from_config(
|
|
||||||
config=self.model_config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.model = self.auto_model_loader(config=self.model_config)
|
|
||||||
else:
|
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
|
||||||
self.base_model,
|
|
||||||
config=self.model_config,
|
|
||||||
**self.model_kwargs,
|
|
||||||
)
|
|
||||||
elif self.model_type == "MambaLMHeadModel":
|
elif self.model_type == "MambaLMHeadModel":
|
||||||
|
if self.cfg.reinit_weights:
|
||||||
|
LOG.warning(
|
||||||
|
"reinit_weights is not supported with MambaLMHeadModel. "
|
||||||
|
"Loading from pretrained weights instead."
|
||||||
|
)
|
||||||
# FIXME this is janky at best and hacked together to make it work
|
# FIXME this is janky at best and hacked together to make it work
|
||||||
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
||||||
|
|
||||||
@@ -764,41 +766,27 @@ class ModelLoader:
|
|||||||
self.base_model,
|
self.base_model,
|
||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
)
|
)
|
||||||
elif (
|
|
||||||
self.model_type
|
|
||||||
and self.model_type != "AutoModelForCausalLM"
|
|
||||||
and not self.cfg.trust_remote_code
|
|
||||||
):
|
|
||||||
if self.cfg.gptq:
|
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
|
||||||
self.base_model,
|
|
||||||
config=self.model_config,
|
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
|
||||||
**self.model_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.model = getattr(transformers, self.model_type).from_pretrained(
|
|
||||||
self.base_model,
|
|
||||||
config=self.model_config,
|
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
|
||||||
**self.model_kwargs,
|
|
||||||
)
|
|
||||||
elif self.cfg.gptq:
|
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
|
||||||
self.base_model,
|
|
||||||
config=self.model_config,
|
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
|
||||||
**self.model_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Please don't remove underscore binding without reading the fn docstring.
|
# Please don't remove underscore binding without reading the fn docstring
|
||||||
_ = self._configure_zero3_memory_efficient_loading()
|
_ = self._configure_zero3_memory_efficient_loading()
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
|
||||||
self.base_model,
|
if (
|
||||||
config=self.model_config,
|
self.model_type
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
and self.model_type != "AutoModelForCausalLM"
|
||||||
**self.model_kwargs,
|
and not self.cfg.trust_remote_code
|
||||||
)
|
and not self.cfg.gptq
|
||||||
|
):
|
||||||
|
# Use model type from transformers
|
||||||
|
model_loader_class = getattr(transformers, self.model_type)
|
||||||
|
else:
|
||||||
|
# Use auto model loader (handles gptq and default cases)
|
||||||
|
model_loader_class = self.auto_model_loader
|
||||||
|
|
||||||
|
if self.cfg.reinit_weights:
|
||||||
|
self.model = self._load_model_from_config()
|
||||||
|
else:
|
||||||
|
self.model = self._load_model_from_pretrained(model_loader_class)
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
|
|
||||||
|
|||||||
@@ -109,6 +109,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
|
"description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
reinit_weights: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Reinitialize model weights randomly instead of loading pretrained weights"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
trainer_cls: str | None = Field(
|
trainer_cls: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
Reference in New Issue
Block a user