Compare commits
1 Commits
kd-logprob
...
kto_fix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92c217677c |
@@ -40,7 +40,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
|
|||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.base import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
AxolotlORPOTrainer,
|
AxolotlORPOTrainer,
|
||||||
AxolotlPRMTrainer,
|
AxolotlPRMTrainer,
|
||||||
@@ -51,6 +50,7 @@ from axolotl.core.trainers.base import (
|
|||||||
from axolotl.core.trainers.dpo import DPOStrategy
|
from axolotl.core.trainers.dpo import DPOStrategy
|
||||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
|
from axolotl.core.trainers.kto import AxolotlKTOTrainer
|
||||||
from axolotl.core.training_args import (
|
from axolotl.core.training_args import (
|
||||||
AxolotlCPOConfig,
|
AxolotlCPOConfig,
|
||||||
AxolotlKTOConfig,
|
AxolotlKTOConfig,
|
||||||
@@ -751,12 +751,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.kd_ce_alpha is not None:
|
if self.cfg.kd_ce_alpha is not None:
|
||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||||
if self.cfg.kd_ce_alpha_end is not None:
|
|
||||||
training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end
|
|
||||||
if self.cfg.kd_alpha is not None:
|
if self.cfg.kd_alpha is not None:
|
||||||
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
||||||
if self.cfg.kd_alpha_end is not None:
|
|
||||||
training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end
|
|
||||||
if self.cfg.kd_temperature is not None:
|
if self.cfg.kd_temperature is not None:
|
||||||
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
||||||
if self.cfg.kd_zscore_base_temp is not None:
|
if self.cfg.kd_zscore_base_temp is not None:
|
||||||
|
|||||||
@@ -20,9 +20,10 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sequential
|
|||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
from trl import CPOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
|
from axolotl.core.trainers.kto import AxolotlKTOTrainer
|
||||||
from axolotl.integrations.base import BaseOptimizerFactory
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
@@ -874,14 +875,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
|
|||||||
7
src/axolotl/core/trainers/kto/__init__.py
Normal file
7
src/axolotl/core/trainers/kto/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
KTO package initialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from axolotl.core.trainers.kto.trainer import AxolotlKTOTrainer
|
||||||
|
|
||||||
|
__all__ = ["AxolotlKTOTrainer"]
|
||||||
512
src/axolotl/core/trainers/kto/trainer.py
Normal file
512
src/axolotl/core/trainers/kto/trainer.py
Normal file
@@ -0,0 +1,512 @@
|
|||||||
|
"""
|
||||||
|
KTO trainer implementation for Axolotl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from typing import Any, Callable, Literal, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from datasets import Dataset
|
||||||
|
from torch.utils.data import DataLoader, SequentialSampler
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
BaseImageProcessor,
|
||||||
|
DataCollator,
|
||||||
|
FeatureExtractionMixin,
|
||||||
|
PreTrainedModel,
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
ProcessorMixin,
|
||||||
|
Trainer,
|
||||||
|
TrainerCallback,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
from transformers.trainer_utils import EvalLoopOutput
|
||||||
|
from trl import KTOTrainer
|
||||||
|
from trl.trainer.kto_config import KTOConfig
|
||||||
|
from trl.trainer.utils import KTODataCollatorWithPadding, pad_to_length
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import SchedulerMixin
|
||||||
|
|
||||||
|
# Check if PEFT is available
|
||||||
|
try:
|
||||||
|
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training, peft_module_casting_to_bf16
|
||||||
|
is_peft_available = True
|
||||||
|
except ImportError:
|
||||||
|
is_peft_available = False
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.core.trainers.kto")
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKTOTrainer(SchedulerMixin, Trainer):
|
||||||
|
"""
|
||||||
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Union[PreTrainedModel, nn.Module, str] = None,
|
||||||
|
args: KTOConfig = None,
|
||||||
|
train_dataset: Optional[Dataset] = None,
|
||||||
|
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||||
|
processing_class: Optional[
|
||||||
|
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||||
|
] = None,
|
||||||
|
data_collator: Optional[DataCollator] = None,
|
||||||
|
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||||
|
callbacks: Optional[list[TrainerCallback]] = None,
|
||||||
|
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||||
|
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||||
|
peft_config: Optional[dict] = None,
|
||||||
|
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||||
|
dataset_tags=None,
|
||||||
|
model_adapter_name: Optional[str] = None,
|
||||||
|
ref_adapter_name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.dataset_tags = dataset_tags
|
||||||
|
self._tag_names = ["trl", "kto"]
|
||||||
|
if hasattr(self, "tag_names"):
|
||||||
|
self._tag_names.extend(self.tag_names)
|
||||||
|
|
||||||
|
if type(args) is TrainingArguments:
|
||||||
|
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
|
||||||
|
|
||||||
|
if args.model_init_kwargs is None:
|
||||||
|
model_init_kwargs = {}
|
||||||
|
elif not isinstance(model, str):
|
||||||
|
raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
|
||||||
|
else:
|
||||||
|
model_init_kwargs = args.model_init_kwargs
|
||||||
|
torch_dtype = model_init_kwargs.get("torch_dtype")
|
||||||
|
if torch_dtype is not None:
|
||||||
|
# Convert to `torch.dtype` if an str is passed
|
||||||
|
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
||||||
|
torch_dtype = getattr(torch, torch_dtype)
|
||||||
|
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
||||||
|
)
|
||||||
|
model_init_kwargs["torch_dtype"] = torch_dtype
|
||||||
|
|
||||||
|
if args.ref_model_init_kwargs is None:
|
||||||
|
ref_model_init_kwargs = {}
|
||||||
|
elif not isinstance(ref_model, str):
|
||||||
|
raise ValueError(
|
||||||
|
"You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ref_model_init_kwargs = args.ref_model_init_kwargs
|
||||||
|
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
|
||||||
|
if torch_dtype is not None:
|
||||||
|
# Convert to `torch.dtype` if an str is passed
|
||||||
|
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
||||||
|
torch_dtype = getattr(torch, torch_dtype)
|
||||||
|
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
||||||
|
)
|
||||||
|
ref_model_init_kwargs["torch_dtype"] = torch_dtype
|
||||||
|
|
||||||
|
if isinstance(model, str):
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
||||||
|
|
||||||
|
if isinstance(ref_model, str):
|
||||||
|
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
|
||||||
|
|
||||||
|
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
||||||
|
# has been called in order to properly call autocast if needed.
|
||||||
|
self._peft_has_been_casted_to_bf16 = False
|
||||||
|
|
||||||
|
if not is_peft_available() and peft_config is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
|
||||||
|
)
|
||||||
|
elif is_peft_available() and peft_config is not None:
|
||||||
|
# if model is a peft model and we have a peft_config, we merge and unload it first
|
||||||
|
if isinstance(model, PeftModel):
|
||||||
|
model = model.merge_and_unload()
|
||||||
|
|
||||||
|
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
||||||
|
_support_gc_kwargs = hasattr(
|
||||||
|
args, "gradient_checkpointing_kwargs"
|
||||||
|
) and "gradient_checkpointing_kwargs" in list(
|
||||||
|
inspect.signature(prepare_model_for_kbit_training).parameters
|
||||||
|
)
|
||||||
|
|
||||||
|
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
||||||
|
|
||||||
|
if _support_gc_kwargs:
|
||||||
|
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
||||||
|
|
||||||
|
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
||||||
|
elif getattr(args, "gradient_checkpointing", False):
|
||||||
|
# For backward compatibility with older versions of transformers
|
||||||
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
|
model.enable_input_require_grads()
|
||||||
|
else:
|
||||||
|
|
||||||
|
def make_inputs_require_grad(module, input, output):
|
||||||
|
output.requires_grad_(True)
|
||||||
|
|
||||||
|
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||||
|
|
||||||
|
# get peft model with the given config
|
||||||
|
model = get_peft_model(model, peft_config)
|
||||||
|
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
||||||
|
peft_module_casting_to_bf16(model)
|
||||||
|
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
||||||
|
self._peft_has_been_casted_to_bf16 = True
|
||||||
|
|
||||||
|
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
||||||
|
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
||||||
|
# fail or completely fail.
|
||||||
|
elif getattr(args, "gradient_checkpointing", False):
|
||||||
|
# For backward compatibility with older versions of transformers
|
||||||
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
|
model.enable_input_require_grads()
|
||||||
|
else:
|
||||||
|
|
||||||
|
def make_inputs_require_grad(module, input, output):
|
||||||
|
output.requires_grad_(True)
|
||||||
|
|
||||||
|
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||||
|
|
||||||
|
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
||||||
|
raise ValueError(
|
||||||
|
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
||||||
|
" Please install `wandb` or `comet-ml` to resolve."
|
||||||
|
)
|
||||||
|
|
||||||
|
if model is not None:
|
||||||
|
self.is_encoder_decoder = model.config.is_encoder_decoder
|
||||||
|
elif args.is_encoder_decoder is None:
|
||||||
|
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
||||||
|
else:
|
||||||
|
self.is_encoder_decoder = args.is_encoder_decoder
|
||||||
|
|
||||||
|
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
||||||
|
self.model_adapter_name = model_adapter_name
|
||||||
|
self.ref_adapter_name = ref_adapter_name
|
||||||
|
|
||||||
|
if ref_model:
|
||||||
|
self.ref_model = ref_model
|
||||||
|
elif self.is_peft_model or args.precompute_ref_log_probs:
|
||||||
|
# The `model` with adapters turned off will be used as the reference model
|
||||||
|
self.ref_model = None
|
||||||
|
else:
|
||||||
|
self.ref_model = create_reference_model(model)
|
||||||
|
|
||||||
|
if processing_class is None:
|
||||||
|
raise ValueError(
|
||||||
|
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
|
||||||
|
)
|
||||||
|
if args.max_length is None:
|
||||||
|
warnings.warn(
|
||||||
|
"When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
|
||||||
|
" it will be set to `512` by default, but you should do it yourself in the future.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
max_length = 512
|
||||||
|
if args.max_length is not None:
|
||||||
|
max_length = args.max_length
|
||||||
|
|
||||||
|
if args.max_prompt_length is None:
|
||||||
|
warnings.warn(
|
||||||
|
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
|
||||||
|
" it will be set to `128` by default, but you should do it yourself in the future.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
max_prompt_length = 128
|
||||||
|
if args.max_prompt_length is not None:
|
||||||
|
max_prompt_length = args.max_prompt_length
|
||||||
|
|
||||||
|
max_completion_length = None
|
||||||
|
if args.max_completion_length is None and self.is_encoder_decoder:
|
||||||
|
warnings.warn(
|
||||||
|
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
|
||||||
|
" it will be set to `128` by default, but you should do it yourself in the future.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
max_completion_length = 128
|
||||||
|
if args.max_completion_length is not None and self.is_encoder_decoder:
|
||||||
|
max_completion_length = args.max_completion_length
|
||||||
|
|
||||||
|
if data_collator is None:
|
||||||
|
data_collator = DPODataCollatorWithPadding(
|
||||||
|
pad_token_id=processing_class.pad_token_id,
|
||||||
|
label_pad_token_id=args.label_pad_token_id,
|
||||||
|
is_encoder_decoder=self.is_encoder_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.remove_unused_columns:
|
||||||
|
args.remove_unused_columns = False
|
||||||
|
# warn users
|
||||||
|
warnings.warn(
|
||||||
|
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
|
||||||
|
" we have set it for you, but you should do it yourself in the future.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.use_dpo_data_collator = True
|
||||||
|
else:
|
||||||
|
self.use_dpo_data_collator = False
|
||||||
|
|
||||||
|
# Disable dropout in the model and reference model
|
||||||
|
if args.disable_dropout:
|
||||||
|
disable_dropout_in_model(model)
|
||||||
|
if self.ref_model is not None:
|
||||||
|
disable_dropout_in_model(self.ref_model)
|
||||||
|
|
||||||
|
self.loss_type = args.loss_type
|
||||||
|
self.max_length = max_length
|
||||||
|
self.generate_during_eval = args.generate_during_eval
|
||||||
|
self.label_pad_token_id = args.label_pad_token_id
|
||||||
|
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
||||||
|
self.max_prompt_length = max_prompt_length
|
||||||
|
self.truncation_mode = args.truncation_mode
|
||||||
|
self.max_completion_length = max_completion_length
|
||||||
|
self.processing_class = processing_class
|
||||||
|
self.precompute_ref_log_probs = args.precompute_ref_log_probs
|
||||||
|
|
||||||
|
# Not all losses require a KL calculation
|
||||||
|
self.calculate_KL = True
|
||||||
|
if self.loss_type in ["apo_zero_unpaired"]:
|
||||||
|
self.calculate_KL = False
|
||||||
|
|
||||||
|
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
|
||||||
|
# keep track of first called to avoid computation of future calls
|
||||||
|
self._precomputed_train_ref_log_probs = False
|
||||||
|
self._precomputed_eval_ref_log_probs = False
|
||||||
|
|
||||||
|
# metric
|
||||||
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
|
# KTO parameter
|
||||||
|
self.beta = args.beta
|
||||||
|
self.desirable_weight = args.desirable_weight
|
||||||
|
self.undesirable_weight = args.undesirable_weight
|
||||||
|
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
||||||
|
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
||||||
|
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
||||||
|
warnings.warn(
|
||||||
|
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
||||||
|
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
||||||
|
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
||||||
|
"loss.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
||||||
|
# input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
|
||||||
|
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
|
||||||
|
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
||||||
|
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
||||||
|
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
||||||
|
# issued.
|
||||||
|
model.warnings_issued["estimate_tokens"] = True
|
||||||
|
|
||||||
|
# Compute that only on the main process for faster data processing.
|
||||||
|
# see: https://github.com/huggingface/trl/pull/1255
|
||||||
|
with PartialState().local_main_process_first():
|
||||||
|
# Extract the prompt if needed
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
|
||||||
|
)
|
||||||
|
# Unpair the dataset if needed
|
||||||
|
train_dataset = maybe_unpair_preference_dataset(
|
||||||
|
train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
|
||||||
|
)
|
||||||
|
# Apply the chat template if needed
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
maybe_apply_chat_template,
|
||||||
|
fn_kwargs={"tokenizer": processing_class},
|
||||||
|
num_proc=args.dataset_num_proc,
|
||||||
|
desc="Applying chat template to train dataset",
|
||||||
|
)
|
||||||
|
if eval_dataset is not None:
|
||||||
|
eval_dataset = eval_dataset.map(
|
||||||
|
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
|
||||||
|
)
|
||||||
|
eval_dataset = maybe_unpair_preference_dataset(
|
||||||
|
eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
|
||||||
|
)
|
||||||
|
eval_dataset = eval_dataset.map(
|
||||||
|
maybe_apply_chat_template,
|
||||||
|
fn_kwargs={"tokenizer": processing_class},
|
||||||
|
num_proc=args.dataset_num_proc,
|
||||||
|
desc="Applying chat template to eval dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tokenize and prepare the training datasets
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
_tokenize,
|
||||||
|
batched=True,
|
||||||
|
fn_kwargs={"tokenizer": self.processing_class},
|
||||||
|
num_proc=args.dataset_num_proc,
|
||||||
|
desc="Tokenizing train dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
fn_kwargs = {
|
||||||
|
"prefix": "",
|
||||||
|
"is_encoder_decoder": self.is_encoder_decoder,
|
||||||
|
"tokenizer": self.processing_class,
|
||||||
|
"max_length": self.max_length,
|
||||||
|
"truncation_mode": self.truncation_mode,
|
||||||
|
"label_pad_token_id": self.label_pad_token_id,
|
||||||
|
"max_prompt_length": self.max_prompt_length,
|
||||||
|
"max_completion_length": self.max_completion_length,
|
||||||
|
}
|
||||||
|
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
_process_tokens,
|
||||||
|
fn_kwargs=fn_kwargs,
|
||||||
|
num_proc=args.dataset_num_proc,
|
||||||
|
desc="Processing tokenized train dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tokenize and prepare the eval datasets
|
||||||
|
if eval_dataset is not None:
|
||||||
|
eval_dataset = eval_dataset.map(
|
||||||
|
_tokenize,
|
||||||
|
fn_kwargs={"tokenizer": self.processing_class},
|
||||||
|
batched=True,
|
||||||
|
num_proc=args.dataset_num_proc,
|
||||||
|
desc="Tokenizing eval dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_dataset = eval_dataset.map(
|
||||||
|
_process_tokens,
|
||||||
|
fn_kwargs=fn_kwargs,
|
||||||
|
num_proc=args.dataset_num_proc,
|
||||||
|
desc="Processing tokenized eval dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get KL datasets if needed
|
||||||
|
if self.calculate_KL:
|
||||||
|
if args.per_device_train_batch_size <= 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
|
||||||
|
)
|
||||||
|
|
||||||
|
# create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
|
||||||
|
# i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
|
||||||
|
train_kl_dataset = train_dataset.map(
|
||||||
|
_get_kl_dataset,
|
||||||
|
batched=True,
|
||||||
|
batch_size=args.per_device_train_batch_size,
|
||||||
|
num_proc=args.dataset_num_proc,
|
||||||
|
desc="Extracting KL train dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
fn_kwargs["prefix"] = "KL_"
|
||||||
|
train_kl_dataset = train_kl_dataset.map(
|
||||||
|
_process_tokens,
|
||||||
|
fn_kwargs=fn_kwargs,
|
||||||
|
num_proc=args.dataset_num_proc,
|
||||||
|
remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
|
||||||
|
desc="Processing tokenized train KL dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
# merge the datasets
|
||||||
|
train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
|
||||||
|
|
||||||
|
if eval_dataset is not None:
|
||||||
|
# Get KL dataset
|
||||||
|
eval_kl_dataset = eval_dataset.map(
|
||||||
|
_get_kl_dataset,
|
||||||
|
batched=True,
|
||||||
|
batch_size=args.per_device_train_batch_size,
|
||||||
|
num_proc=args.dataset_num_proc,
|
||||||
|
desc="Extracting eval KL dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_kl_dataset = eval_kl_dataset.map(
|
||||||
|
_process_tokens,
|
||||||
|
fn_kwargs=fn_kwargs,
|
||||||
|
num_proc=args.dataset_num_proc,
|
||||||
|
remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
|
||||||
|
desc="Processing tokenized eval KL dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
# merge the datasets
|
||||||
|
eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
|
||||||
|
|
||||||
|
# calculate dataset desirability balance
|
||||||
|
num_desirable = max(sum(train_dataset["label"]), 1)
|
||||||
|
num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
|
||||||
|
|
||||||
|
if num_desirable != num_undesirable:
|
||||||
|
# The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
|
||||||
|
des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
|
||||||
|
des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
|
||||||
|
und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
|
||||||
|
und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
|
||||||
|
|
||||||
|
des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
|
||||||
|
und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
|
||||||
|
|
||||||
|
if not (des_weight_in_range or und_weight_in_range):
|
||||||
|
warnings.warn(
|
||||||
|
"You have different amounts of desirable/positive and undesirable/negative examples but the "
|
||||||
|
"weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
|
||||||
|
f"on your data, we recommend EITHER "
|
||||||
|
f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
|
||||||
|
f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
|
||||||
|
"See the documentation on how to optimally set these weights.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
model=model,
|
||||||
|
args=args,
|
||||||
|
data_collator=data_collator,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=processing_class,
|
||||||
|
model_init=model_init,
|
||||||
|
compute_metrics=compute_metrics,
|
||||||
|
callbacks=callbacks,
|
||||||
|
optimizers=optimizers,
|
||||||
|
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
||||||
|
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
||||||
|
# self.model_accepts_loss_kwargs to False to enable scaling.
|
||||||
|
self.model_accepts_loss_kwargs = False
|
||||||
|
|
||||||
|
# Add tags for models that have been loaded with the correct transformers version
|
||||||
|
if hasattr(self.model, "add_model_tags"):
|
||||||
|
self.model.add_model_tags(self._tag_names)
|
||||||
|
|
||||||
|
if not hasattr(self, "accelerator"):
|
||||||
|
raise AttributeError(
|
||||||
|
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Deepspeed Zero-3 does not support precompute_ref_log_probs
|
||||||
|
if self.is_deepspeed_enabled:
|
||||||
|
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.ref_model is None:
|
||||||
|
if not (self.is_peft_model or self.precompute_ref_log_probs):
|
||||||
|
raise ValueError(
|
||||||
|
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.is_deepspeed_enabled:
|
||||||
|
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||||
|
else:
|
||||||
|
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||||
@@ -34,12 +34,3 @@ class KDPlugin(BasePlugin):
|
|||||||
|
|
||||||
return AxolotlKDTrainer
|
return AxolotlKDTrainer
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
|
||||||
callbacks = []
|
|
||||||
if cfg.kd_trainer:
|
|
||||||
from .callbacks import KDAlphaSchedulerCallback
|
|
||||||
|
|
||||||
callbacks.append(KDAlphaSchedulerCallback())
|
|
||||||
|
|
||||||
return callbacks
|
|
||||||
|
|||||||
@@ -30,8 +30,6 @@ class KDArgs(BaseModel):
|
|||||||
float
|
float
|
||||||
] = None # loss coefficient for cross-entropy loss during KD
|
] = None # loss coefficient for cross-entropy loss during KD
|
||||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||||
kd_ce_alpha_end: Optional[float] = None # end value for kd_ce_alpha
|
|
||||||
kd_alpha_end: Optional[float] = None # end value for kd_alpha
|
|
||||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
||||||
kd_top_k_before_softmax: Optional[
|
kd_top_k_before_softmax: Optional[
|
||||||
|
|||||||
@@ -1,28 +0,0 @@
|
|||||||
from transformers import TrainerCallback
|
|
||||||
|
|
||||||
|
|
||||||
class KDAlphaSchedulerCallback(TrainerCallback):
|
|
||||||
"""Callback to for scheduling KD alpha during training."""
|
|
||||||
|
|
||||||
def on_epoch_begin(
|
|
||||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
if int(state.epoch) == 0:
|
|
||||||
state.kd_alpha = args.kd_alpha
|
|
||||||
state.kd_ce_alpha = args.kd_ce_alpha
|
|
||||||
elif int(state.epoch) == state.num_train_epochs - 1:
|
|
||||||
if args.kd_alpha_end is not None:
|
|
||||||
control.kd_alpha = args.kd_alpha_end
|
|
||||||
if args.kd_ce_alpha_end is not None:
|
|
||||||
control.kd_ce_alpha = args.kd_ce_alpha_end
|
|
||||||
else:
|
|
||||||
epoch_steps = state.num_train_epochs - 1
|
|
||||||
scale = int(state.epoch) / epoch_steps
|
|
||||||
if args.kd_alpha_end is not None:
|
|
||||||
control.kd_alpha = (
|
|
||||||
args.kd_alpha + (args.kd_alpha_end - args.kd_alpha) * scale
|
|
||||||
)
|
|
||||||
if args.kd_ce_alpha_end is not None:
|
|
||||||
control.kd_ce_alpha = (
|
|
||||||
args.kd_ce_alpha + (args.kd_ce_alpha_end - args.kd_ce_alpha) * scale
|
|
||||||
)
|
|
||||||
@@ -62,16 +62,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
Transform logprobs to target format for KD training
|
Transform logprobs to target format for KD training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys():
|
logprobs = sample.pop(self.logprobs_field)
|
||||||
logprobs = sample.pop("target_logprobs")
|
|
||||||
token_ids = sample.pop("target_token_ids")
|
|
||||||
else:
|
|
||||||
logprobs = sample.pop(self.logprobs_field)
|
|
||||||
token_ids = [None] * len(logprobs)
|
|
||||||
|
|
||||||
target_seq_len = len(logprobs)
|
target_seq_len = len(logprobs)
|
||||||
input_seq_len = len(sample["input_ids"])
|
input_seq_len = len(sample["input_ids"])
|
||||||
target_padding_len = input_seq_len - target_seq_len
|
input_padding_len = input_seq_len - target_seq_len
|
||||||
# get non-zero top-k (prune None logprobs from vllm data step)
|
# get non-zero top-k (prune None logprobs from vllm data step)
|
||||||
top_k_vals = [
|
top_k_vals = [
|
||||||
len(logprobs[i])
|
len(logprobs[i])
|
||||||
@@ -88,11 +82,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
target_token_ids = []
|
target_token_ids = []
|
||||||
target_mask = []
|
target_mask = []
|
||||||
|
|
||||||
if target_padding_len < 0:
|
if input_padding_len < 0:
|
||||||
# logprobs is longer than target_seq_len,
|
# logprobs is longer than target_seq_len,
|
||||||
# so we need to slice from the left/beginning of logprobs
|
# so we need to slice from the left/beginning of logprobs
|
||||||
logprobs = logprobs[:-input_seq_len]
|
logprobs = logprobs[:-input_seq_len]
|
||||||
target_padding_len = 0
|
input_padding_len = 0
|
||||||
# target_seq_len = input_seq_len
|
# target_seq_len = input_seq_len
|
||||||
|
|
||||||
# truncate the second dimension of the logprobs to top_k
|
# truncate the second dimension of the logprobs to top_k
|
||||||
@@ -104,37 +98,33 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
||||||
# otherwise, we need to shift in the trainer
|
# otherwise, we need to shift in the trainer
|
||||||
shift = 0
|
shift = 0
|
||||||
for _ in range(shift, target_padding_len):
|
for _ in range(shift, input_padding_len):
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
target_token_ids.append(list(range(top_k)))
|
target_token_ids.append(list(range(top_k)))
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
|
|
||||||
for position in range(target_padding_len, input_seq_len):
|
for position in range(input_padding_len, input_seq_len):
|
||||||
if sample["labels"][position] == -100:
|
if sample["labels"][position] == -100:
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
else:
|
else:
|
||||||
target_mask.append([1] * top_k)
|
target_mask.append([1] * top_k)
|
||||||
|
|
||||||
for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids):
|
for _, token_pos_logprobs in enumerate(logprobs):
|
||||||
# Initialize collections for logprobs and token_ids
|
# Initialize collections for logprobs and token_ids
|
||||||
position_logprobs = []
|
position_logprobs = []
|
||||||
position_token_ids = []
|
position_token_ids = []
|
||||||
|
|
||||||
# Process each token probability entry
|
# Process each token probability entry
|
||||||
if token_pos_token_ids is None:
|
for entry in token_pos_logprobs:
|
||||||
for entry in token_pos_logprobs:
|
# Extract logprob value
|
||||||
# Extract logprob value
|
logprob = entry["logprob"]
|
||||||
logprob = entry["logprob"]
|
|
||||||
|
|
||||||
# Parse token_id from the "token_id:###" format
|
# Parse token_id from the "token_id:###" format
|
||||||
token_id = int(entry["token"].split(":")[1])
|
token_id = int(entry["token"].split(":")[1])
|
||||||
|
|
||||||
# Append to our collections
|
# Append to our collections
|
||||||
position_logprobs.append(logprob)
|
position_logprobs.append(logprob)
|
||||||
position_token_ids.append(token_id)
|
position_token_ids.append(token_id)
|
||||||
else:
|
|
||||||
position_logprobs = token_pos_logprobs
|
|
||||||
position_token_ids = token_pos_token_ids
|
|
||||||
|
|
||||||
# Convert to a tensor for easier manipulation
|
# Convert to a tensor for easier manipulation
|
||||||
position_logprobs_tensor = torch.tensor(
|
position_logprobs_tensor = torch.tensor(
|
||||||
@@ -153,7 +143,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||||
else:
|
else:
|
||||||
teacher_probs_t2 = teacher_probs_t1
|
teacher_probs_t2 = teacher_probs_t1
|
||||||
|
|
||||||
# Re-normalize
|
# Re-normalize
|
||||||
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||||
dim=0, keepdim=True
|
dim=0, keepdim=True
|
||||||
|
|||||||
@@ -16,35 +16,17 @@
|
|||||||
KD trainer
|
KD trainer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from transformers import TrainerControl
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||||
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKDTrainerControl(TrainerControl):
|
|
||||||
kd_alpha: float = 1.0
|
|
||||||
kd_ce_alpha: float = 0.0
|
|
||||||
|
|
||||||
def state(self) -> dict:
|
|
||||||
state_val = super().state()
|
|
||||||
state_val["args"]["kd_alpha"] = self.kd_alpha
|
|
||||||
state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKDTrainer(AxolotlTrainer):
|
class AxolotlKDTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
Custom trainer subclass for Knowledge Distillation (KD)
|
Custom trainer subclass for Knowledge Distillation (KD)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.kd_alpha = self.args.kd_alpha
|
|
||||||
self.kd_ce_alpha = self.args.kd_ce_alpha
|
|
||||||
self.control = AxolotlKDTrainerControl()
|
|
||||||
|
|
||||||
def _set_signature_columns_if_needed(self):
|
def _set_signature_columns_if_needed(self):
|
||||||
super()._set_signature_columns_if_needed()
|
super()._set_signature_columns_if_needed()
|
||||||
columns_to_add = []
|
columns_to_add = []
|
||||||
@@ -113,8 +95,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.kd_ce_alpha > 0:
|
if self.args.kd_ce_alpha > 0:
|
||||||
loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd
|
kd_alpha = self.args.kd_alpha
|
||||||
|
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
||||||
else:
|
else:
|
||||||
loss = loss_kd
|
loss = loss_kd
|
||||||
# Save past state if it exists
|
# Save past state if it exists
|
||||||
|
|||||||
@@ -813,15 +813,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
)
|
)
|
||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||||
# TODO if using deepspeed and it's a file, save deepspeed config too
|
|
||||||
if args.deepspeed and os.path.isfile(args.deepspeed):
|
|
||||||
LOG.info(f"DeepSpeed config has been saved to the WandB run.")
|
|
||||||
artifact = wandb.Artifact(
|
|
||||||
f"deepspeed-{wandb.run.id}", type="deepspeed-config"
|
|
||||||
)
|
|
||||||
artifact.add_file(args.deepspeed)
|
|
||||||
wandb.log_artifact(artifact)
|
|
||||||
wandb.save(args.deepspeed)
|
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -173,16 +173,10 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
else:
|
else:
|
||||||
try:
|
arrays = [
|
||||||
arrays = [
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
np.array(item[feature])
|
]
|
||||||
for item in features_
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
if feature in item
|
|
||||||
]
|
|
||||||
if arrays[0].dtype != "object":
|
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -121,6 +121,7 @@ def drop_long_rl_seq(
|
|||||||
|
|
||||||
|
|
||||||
def load_prepare_preference_datasets(cfg):
|
def load_prepare_preference_datasets(cfg):
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
def load_split(dataset_cfgs, _cfg):
|
def load_split(dataset_cfgs, _cfg):
|
||||||
split_datasets: List[Any] = []
|
split_datasets: List[Any] = []
|
||||||
use_auth_token = _cfg.hf_use_auth_token
|
use_auth_token = _cfg.hf_use_auth_token
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from peft import (
|
|||||||
PeftModelForCausalLM,
|
PeftModelForCausalLM,
|
||||||
prepare_model_for_kbit_training,
|
prepare_model_for_kbit_training,
|
||||||
)
|
)
|
||||||
|
from peft.tuners.lora import QuantLinear
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import ( # noqa: F401
|
from transformers import ( # noqa: F401
|
||||||
AddedToken,
|
AddedToken,
|
||||||
@@ -1359,7 +1360,7 @@ def load_llama_adapter(model, cfg):
|
|||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ def fixture_cfg():
|
|||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"rl": True,
|
"rl": True,
|
||||||
"adam_beta1": 0.91,
|
"adam_beta1": 0.998,
|
||||||
"adam_beta2": 0.998,
|
"adam_beta2": 0.9,
|
||||||
"adam_epsilon": 0.00001,
|
"adam_epsilon": 0.00001,
|
||||||
"dataloader_num_workers": 1,
|
"dataloader_num_workers": 1,
|
||||||
"dataloader_pin_memory": True,
|
"dataloader_pin_memory": True,
|
||||||
@@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder:
|
|||||||
def test_build_training_arguments(self, cfg, model, tokenizer):
|
def test_build_training_arguments(self, cfg, model, tokenizer):
|
||||||
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
||||||
training_arguments = builder.build_training_arguments(100)
|
training_arguments = builder.build_training_arguments(100)
|
||||||
assert training_arguments.adam_beta1 == 0.91
|
assert training_arguments.adam_beta1 == 0.998
|
||||||
assert training_arguments.adam_beta2 == 0.998
|
assert training_arguments.adam_beta2 == 0.9
|
||||||
assert training_arguments.adam_epsilon == 0.00001
|
assert training_arguments.adam_epsilon == 0.00001
|
||||||
assert training_arguments.dataloader_num_workers == 1
|
assert training_arguments.dataloader_num_workers == 1
|
||||||
assert training_arguments.dataloader_pin_memory is True
|
assert training_arguments.dataloader_pin_memory is True
|
||||||
|
|||||||
Reference in New Issue
Block a user