Files
axolotl/src/axolotl/utils/trainer.py
Wing Lian 11c0b5b256 bartch upgrade dependencies (#3299)
* upgrade dependencies

* don't use reset sessions

* downgrade transformers, upgrade other deps

* upgrade bnb to 0.49.0

* restore s3 cache

* explicit use local files w hub

* decompress and strip top level dir

* use 2 levels for strip components

* try to preserve permissions for symlinks

* use updated tar

* fix #3293 for distributed

* downgrade bnb

* fast fail after 4

* fix total tokens device

* patch accelerate CP/SP (#3309)

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-12-30 09:02:49 -05:00

733 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Module containing the Trainer class and related functions"""
import json
import math
import os
import random
from contextlib import contextmanager
from functools import partial
from tempfile import NamedTemporaryFile
from typing import List, Optional
import numpy as np
import torch
import torch.cuda
from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
@torch.jit.script
def weighted_cross_entropy(
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
):
# Flatten the logits, labels, and weights tensors
logits = logits.view(
-1, logits.size(-1)
) # logits becomes of shape [batch_size*sequence_length, vocab_size]
labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length]
weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length]
# Compute the unweighted cross entropy loss
losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
# Apply the weights to the losses and compute their sum
return (weights * losses).sum()
@torch.jit.script
def create_weighted_mask(labels: torch.Tensor):
# Check if the tensor is 2D. If not, unsqueeze it to make it 2D
if len(labels.shape) == 1:
labels = labels.unsqueeze(0)
weights = torch.zeros_like(labels).float()
for i in range(labels.shape[0]):
mask = labels[i] != -100
# Create a tensor to track group ids
group_ids = torch.zeros_like(labels[i]).int()
curr_group_id = 0
for j in range(1, len(labels[i])):
if mask[j] and not mask[j - 1]: # switch from masked to unmasked label
curr_group_id += 1 # start new group
group_ids[j] = (
curr_group_id if mask[j] else 0
) # assign group id if unmasked label
# Count only unmasked labels in each group
group_counts = torch.bincount(group_ids[mask])
mask_weights = torch.zeros_like(labels[i]).float()
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
weights[i] = mask_weights
return weights.squeeze() # squeeze the output to match the input dimension
def trainer_weighted_loss(model_output, labels, shift_labels=True):
logits = (
model_output["logits"] if isinstance(model_output, dict) else model_output[0]
)
if shift_labels:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
weights = create_weighted_mask(labels)
return weighted_cross_entropy(logits, labels, weights)
@contextmanager
def disable_datasets_caching():
try:
disable_caching()
yield
finally:
enable_caching()
def add_position_ids(sample):
"""
Handle both single-example and batched data.
- single example: sample['input_ids'] is a list[int]
- batched data: sample['input_ids'] is a list[list[int]]
"""
# Return sample unchanged if "input_ids" is not present, or is empty
if "input_ids" not in sample or not sample["input_ids"]:
return sample
input_ids = sample["input_ids"]
# If first element is an int, its a single example
# If first element is a list, its a batch
if isinstance(input_ids[0], int):
# ---- SINGLE EXAMPLE ----
seq_len = len(input_ids)
# Position IDs for a single example
# As a list
sample["position_ids"] = list(range(seq_len))
sample["length"] = seq_len
else:
# ---- BATCHED EXAMPLES ----
# input_ids is a list of lists
position_ids_batch = []
lengths_batch = []
for seq in input_ids:
seq_len = len(seq)
position_ids_batch.append(list(range(seq_len)))
lengths_batch.append(seq_len)
# Now store them back
sample["position_ids"] = position_ids_batch
sample["length"] = lengths_batch
return sample
def add_pose_position_ids(
sample,
max_context_len=32768,
split_on_token_ids: Optional[List[int]] = None,
chunks: int = 2,
):
"""
use the PoSE technique to extend the context length by randomly skipping
positions in the context. We only want to skip right before tokens in
the split_on_token_ids list. We should attempt to randomly distribute
the skips, but we don't need the final position_ids to be the full
context_len. There may be multiple turns in the context, so we want to
make sure we take into account the maximum possible number of skips
remaining in each sample.
"""
input_ids = sample["input_ids"]
sample_len = len(input_ids)
max_skips = max_context_len - sample_len
if split_on_token_ids is None:
split_on_token_ids = []
if split_on_token_ids:
split_indices = [
i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
]
else:
chunk_len = sample_len // chunks
split_indices = [i * chunk_len for i in range(1, chunks)]
split_indices.append(len(input_ids)) # make sure we go to the end of the sample
if split_indices[0] < 2:
# drop the first split index if it's too close to the beginning
split_indices = split_indices[1:]
position_ids = []
prev_index = 0
total_skips = 0
for split_index in split_indices:
num_skips = (
random.randint(0, max_skips) # nosec B311
if prev_index != 0 and max_skips
else 0
)
max_skips -= num_skips
total_skips += num_skips
segment_position_ids = list(
range(prev_index + total_skips, split_index + total_skips)
)
position_ids.extend(segment_position_ids)
prev_index = split_index
sample["sequence_len"] = position_ids[-1]
position_ids = torch.tensor(position_ids)
sample["position_ids"] = position_ids
sample["length"] = len(position_ids)
assert len(position_ids) == len(input_ids)
return sample
def add_length(sample):
sample["length"] = len(sample["input_ids"])
return sample
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
If raise_on_drop is set, the code raises a ValueError if a sample is
encountered that is too long and would have been dropped.
"""
min_sequence_len = min_sequence_len or 2
input_ids = sample["input_ids"]
# Edge case: if input_ids is empty
if not input_ids:
# Decide if you want to drop or keep empty. Let's drop.
return False
# Check if single example or batched by looking at the first element
if isinstance(input_ids[0], int):
# Single example (input_ids is a list of int)
length = len(input_ids)
if raise_on_drop and length > sequence_len:
raise ValueError(
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
)
return min_sequence_len <= length <= sequence_len
# Batched (input_ids is a list of lists)
results = []
for seq in input_ids:
length = len(seq)
if raise_on_drop and length > sequence_len:
raise ValueError(
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
)
results.append(min_sequence_len <= length <= sequence_len)
return results
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
if drop_attn_mask:
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type in ["falcon", "mistral"]:
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
def drop_no_trainable_tokens(sample):
"""
Drop samples if all labels are -100 (i.e., zero trainable tokens).
Works for both single-example or batched input.
"""
labels = sample["labels"]
if not labels:
return True
# Check if single example or batch
# If first element is an int, we assume a single example
# If it's a list, we assume we're dealing with a batch
if isinstance(labels[0], int):
# Single example: return a single bool
return np.any(labels != -100)
# Batched: 'labels' is a list of lists
# Return a list of booleans, one per sub-list
results = [np.any(row_labels != -100) for row_labels in labels]
return results
try:
prior_len = len(train_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from train dataset"
)
if eval_dataset:
try:
prior_len = len(eval_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
eval_dataset = eval_dataset.filter(
drop_no_trainable_tokens,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
)
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length",
)
if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
train_dataset = train_dataset.map(
add_position_ids,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if cfg.eval_sample_packing:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
**filter_map_kwargs,
**drop_long_kwargs,
)
return train_dataset, eval_dataset
def process_pretraining_datasets_for_packing(
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
train_dataset = train_dataset.filter(
drop_long,
desc="Dropping Long Sequences",
load_from_cache_file=False,
)
if not skip_position_ids:
train_dataset = train_dataset.map(
add_position_ids,
batched=True,
desc="Add position_id column (Pretraining Sample Packing)",
)
if drop_attention_mask:
train_dataset = train_dataset.remove_columns("attention_mask")
return train_dataset
def calculate_total_num_steps(cfg, train_dataset, update=True):
if (
not cfg.total_num_tokens
and not cfg.skip_prepare_dataset
and not cfg.reward_model
):
total_num_tokens = np.sum(
train_dataset.select_columns("input_ids")
.to_pandas()["input_ids"]
.apply(len)
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens:_}")
if update:
cfg.total_num_tokens = total_num_tokens
skip_estimates = cfg.model_config_type == "mamba"
if (
not skip_estimates
and not cfg.total_supervised_tokens
and not cfg.skip_prepare_dataset
and not cfg.reward_model
):
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.debug(f"`total_supervised_tokens: {total_supervised_tokens:_}`")
if update:
cfg.total_supervised_tokens = total_supervised_tokens
if not skip_estimates and cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails
if cfg.sample_packing_eff_est:
total_num_steps = (
# match count to len est in dataloader
int(
math.floor(
0.99
* cfg.total_num_tokens
/ cfg.sample_packing_eff_est
/ cfg.sequence_len
// cfg.batch_size
)
- 1
)
* cfg.num_epochs
* cfg.context_parallel_size
* cfg.tensor_parallel_size
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
)
else:
if cfg.flash_attention and not cfg.multipack_real_batches:
sampler_batch_size = 1
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
else:
sampler_batch_size = cfg.micro_batch_size
batch_max_len = cfg.sequence_len
if cfg.curriculum_sampling:
sampler = SequentialSampler(train_dataset)
else:
sampler = RandomSampler(train_dataset)
sampler = MultipackBatchSampler(
sampler=sampler,
lengths=get_dataset_lengths(train_dataset),
batch_size=sampler_batch_size,
batch_max_len=batch_max_len,
group_size=cfg.sample_packing_group_size,
bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True,
num_processes=cfg.dataset_prcoesses,
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
)
data_loader = DataLoader(
train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
data_loader_len = max(
1, len(data_loader) * cfg.micro_batch_size // cfg.batch_size
)
LOG.debug(f"data_loader_len: {data_loader_len}")
# FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est
total_num_steps = int(
math.floor(
data_loader_len
* cfg.num_epochs
* cfg.context_parallel_size
* cfg.tensor_parallel_size
)
)
if cfg.dataloader_drop_last:
# drop the last batch for each epoch
total_num_steps -= int(math.ceil(cfg.num_epochs))
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
return max(estimates)
sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: sampler.efficiency(),
calc_sample_packing_eff_est,
)
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
)
if update:
cfg.sample_packing_eff_est = sample_packing_eff_est
LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
else:
total_num_steps = int(
math.ceil(
len(train_dataset)
* cfg.num_epochs
* cfg.context_parallel_size
* cfg.tensor_parallel_size
/ cfg.batch_size
)
)
LOG.debug(f"total_num_steps: {total_num_steps}")
return total_num_steps
def setup_torch_compile_env(cfg):
if cfg.torch_compile:
if not cfg.torch_compile_backend:
os.environ["ACCELERATE_DYNAMO_BACKEND"] = "INDUCTOR"
else:
os.environ["ACCELERATE_DYNAMO_BACKEND"] = cfg.torch_compile_backend.upper()
def setup_deepspeed_env(cfg, stage=None):
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
from axolotl.utils.distributed import distributed_state
if distributed_state and distributed_state.initialized:
raise RuntimeError(
"Distributed State already initialized before Deepspeed setup"
)
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
if isinstance(cfg.deepspeed, DictDefault):
with NamedTemporaryFile(
mode="w", delete=False, suffix=".json", prefix="deepspeed_config_"
) as temp_file:
temp_file.write(json.dumps(cfg.deepspeed.to_dict(), indent=4))
temp_file.close()
cfg.deepspeed = str(temp_file.name)
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(
cfg.gradient_accumulation_steps
)
if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
device_count = torch.cuda.device_count()
if device_count == 1:
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("MASTER_ADDR", "0.0.0.0") # nosec B104
os.environ.setdefault("MASTER_PORT", "29500")
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.
if (
int(os.environ.get("WORLD_SIZE", "1")) == 1
and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
and cfg.use_ray is not True
):
os.environ["WORLD_SIZE"] = "1" # force it in case not set
os.environ["LOCAL_RANK"] = "0" # force it in case not set
os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")
import deepspeed.comm as dist
dist.init_distributed(
dist_backend="nccl", auto_mpi_discovery=False, dist_init_required=True
)
init_distributed_state()
# If we don't assign this, it doesn't actually get set in the accelerate weakref
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
# TODO @SalmanMohammadi remove FSDP1 args in 0.12
if str(cfg.fsdp_version) == "2":
os.environ["FSDP_VERSION"] = "2"
if cfg.fsdp_config.activation_checkpointing:
os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
if cfg.fsdp_config.offload_params:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
if cfg.fsdp_config.sync_module_states:
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
if cfg.fsdp_config.cpu_ram_efficient_loading:
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
if cfg.fsdp_config.use_orig_params:
os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
if cfg.fsdp_config.state_dict_type:
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.state_dict_type
if cfg.fsdp_config.cpu_offload_pin_memory is not None:
os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = str(
cfg.fsdp_config.cpu_offload_pin_memory
).lower()
if cfg.fsdp_config.auto_wrap_policy:
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.auto_wrap_policy
if cfg.fsdp_config.transformer_layer_cls_to_wrap:
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = (
cfg.fsdp_config.transformer_layer_cls_to_wrap
)
if cfg.fsdp_config.reshard_after_forward:
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"
def setup_parallelism_envs(cfg):
set_accelerate_parallelism_config = False
if cfg.tensor_parallel_size and cfg.tensor_parallel_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_TP_SIZE"] = str(cfg.tensor_parallel_size)
if cfg.dp_shard_size and cfg.dp_shard_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_DP_SHARD_SIZE"] = str(cfg.dp_shard_size)
if cfg.dp_replicate_size and cfg.dp_replicate_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_DP_REPLICATE_SIZE"] = str(cfg.dp_replicate_size)
if cfg.context_parallel_size and cfg.context_parallel_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
from axolotl.monkeypatch.accelerate.parallelism_config import patch_prepare_cp
patch_prepare_cp()
if set_accelerate_parallelism_config:
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None:
LOG.warning("P2P support not detected, setting `NCCL_P2P_DISABLE=1`")
os.environ["NCCL_P2P_DISABLE"] = "1"
# TODO @SalmanMohammadi remove the cfg.fsdp check in 0.12
if cfg.fsdp or cfg.fsdp_config:
cfg.fsdp = True if not cfg.fsdp else cfg.fsdp
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
stage = None
deepspeed_config = None
# check if the cfg.deepspeed is a file
if isinstance(cfg.deepspeed, DictDefault):
deepspeed_config = cfg.deepspeed
elif os.path.isfile(cfg.deepspeed):
# parse with json
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
deepspeed_config = json.load(fin)
if deepspeed_config:
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage)
setup_parallelism_envs(cfg)
setup_torch_compile_env(cfg)
if cfg.fp8:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
elif (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
else:
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
def setup_trainer(
cfg,
train_dataset,
eval_dataset,
model,
tokenizer,
processor,
total_num_steps,
model_ref=None,
peft_config=None,
):
"""
Helper method for instantiating and building a (causal or RLHF) trainer.
Args:
cfg: Axolotl config object containing training parameters.
train_dataset: Dataset to use for training.
eval_dataset: Dataset to use for evaluation.
model: The model to train.
tokenizer: Tokenizer for processing text input.
processor: Processor for data preparation.
total_num_steps: The total number of training steps.
model_ref: Optional reference model for RLHF training. Default is None.
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
Returns:
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model_ref
trainer_builder.peft_config = peft_config
else:
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset
return trainer_builder.build(total_num_steps)