Compare commits
1 Commit
liger-065 ... kd-logprob
| Author | SHA1 | Date |
|---|---|---|
|  | 8fc4c420a4 |  |
@@ -751,8 +751,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        if self.cfg.kd_ce_alpha is not None:
            training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
+       if self.cfg.kd_ce_alpha_end is not None:
+           training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end
        if self.cfg.kd_alpha is not None:
            training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
+       if self.cfg.kd_alpha_end is not None:
+           training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end
        if self.cfg.kd_temperature is not None:
            training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
        if self.cfg.kd_zscore_base_temp is not None:
            training_arguments_kwargs["kd_zscore_base_temp"] = self.cfg.kd_zscore_base_temp
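For orientation, this hunk just forwards optional plugin config values into the trainer keyword arguments, with the two new `*_end` options carried alongside their starting values. A hypothetical example of the assembled kwargs (key names come from this diff; the values are invented for illustration):

```python
# Hypothetical values only; any subset of these keys may be present.
training_arguments_kwargs = {
    "kd_ce_alpha": 0.2,          # starting weight on the ordinary cross-entropy loss
    "kd_ce_alpha_end": 0.0,      # CE weight to anneal toward (new in this change)
    "kd_alpha": 0.8,             # starting weight on the distillation loss
    "kd_alpha_end": 1.0,         # KD weight to anneal toward (new in this change)
    "kd_temperature": 1.0,       # distillation temperature
    "kd_zscore_base_temp": 1.0,  # base temperature for z-score scaling
}
```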
@@ -34,3 +34,12 @@ class KDPlugin(BasePlugin):
            return AxolotlKDTrainer
        return None

+    def add_callbacks_post_trainer(self, cfg, trainer):
+        callbacks = []
+        if cfg.kd_trainer:
+            from .callbacks import KDAlphaSchedulerCallback
+
+            callbacks.append(KDAlphaSchedulerCallback())
+
+        return callbacks
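The list returned here is presumably attached to the trainer by axolotl's plugin manager; a minimal sketch of that wiring is below. The surrounding loop is an assumption, not code from this PR, but `Trainer.add_callback` is the standard transformers API.

```python
# Sketch only: how the callbacks returned by add_callbacks_post_trainer might be wired up.
plugin = KDPlugin()
for callback in plugin.add_callbacks_post_trainer(cfg, trainer):
    trainer.add_callback(callback)  # standard transformers Trainer method
```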
@@ -30,6 +30,8 @@ class KDArgs(BaseModel):
        float
    ] = None  # loss coefficient for cross-entropy loss during KD
    kd_alpha: Optional[float] = None  # loss coefficient for KD loss
+   kd_ce_alpha_end: Optional[float] = None  # end value for kd_ce_alpha
+   kd_alpha_end: Optional[float] = None  # end value for kd_alpha
    kd_temperature: Optional[float] = None  # temperature for sampling during KD
    kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
    kd_top_k_before_softmax: Optional[
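Since `KDArgs` is a pydantic `BaseModel`, the two new fields validate like the existing ones and default to `None` when unset. A quick hypothetical instantiation (values invented):

```python
# All fields are Optional; omitted fields stay None and the feature is skipped downstream.
kd_args = KDArgs(
    kd_ce_alpha=0.2,
    kd_ce_alpha_end=0.0,
    kd_alpha=0.8,
    kd_alpha_end=1.0,
    kd_temperature=1.0,
)
```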
src/axolotl/integrations/kd/callbacks.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+from transformers import TrainerCallback
+
+
+class KDAlphaSchedulerCallback(TrainerCallback):
+    """Callback for scheduling KD alpha during training."""
+
+    def on_epoch_begin(
+        self, args, state, control, **kwargs  # pylint: disable=unused-argument
+    ):
+        if int(state.epoch) == 0:
+            state.kd_alpha = args.kd_alpha
+            state.kd_ce_alpha = args.kd_ce_alpha
+        elif int(state.epoch) == state.num_train_epochs - 1:
+            if args.kd_alpha_end is not None:
+                control.kd_alpha = args.kd_alpha_end
+            if args.kd_ce_alpha_end is not None:
+                control.kd_ce_alpha = args.kd_ce_alpha_end
+        else:
+            epoch_steps = state.num_train_epochs - 1
+            scale = int(state.epoch) / epoch_steps
+            if args.kd_alpha_end is not None:
+                control.kd_alpha = (
+                    args.kd_alpha + (args.kd_alpha_end - args.kd_alpha) * scale
+                )
+            if args.kd_ce_alpha_end is not None:
+                control.kd_ce_alpha = (
+                    args.kd_ce_alpha + (args.kd_ce_alpha_end - args.kd_ce_alpha) * scale
+                )
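To make the schedule concrete, here is a small standalone sketch (not part of the PR) that reproduces the same linear ramp the callback applies at each epoch boundary: the start value on epoch 0, the end value on the final epoch, and linear interpolation in between. The example start/end values are hypothetical.

```python
def scheduled_alpha(epoch, num_epochs, alpha_start, alpha_end):
    """Linearly interpolate from alpha_start to alpha_end across epochs,
    mirroring KDAlphaSchedulerCallback.on_epoch_begin."""
    if epoch == 0:
        return alpha_start
    if epoch >= num_epochs - 1:
        return alpha_end
    scale = epoch / (num_epochs - 1)
    return alpha_start + (alpha_end - alpha_start) * scale

# With hypothetical values kd_alpha=0.8 and kd_alpha_end=1.0 over 5 epochs:
# epochs 0..4 -> 0.80, 0.85, 0.90, 0.95, 1.00
for epoch in range(5):
    print(epoch, scheduled_alpha(epoch, 5, 0.8, 1.0))
```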
@@ -62,10 +62,16 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
        Transform logprobs to target format for KD training
        """

-       logprobs = sample.pop(self.logprobs_field)
+       if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys():
+           logprobs = sample.pop("target_logprobs")
+           token_ids = sample.pop("target_token_ids")
+       else:
+           logprobs = sample.pop(self.logprobs_field)
+           token_ids = [None] * len(logprobs)

        target_seq_len = len(logprobs)
        input_seq_len = len(sample["input_ids"])
-       input_padding_len = input_seq_len - target_seq_len
+       target_padding_len = input_seq_len - target_seq_len
        # get non-zero top-k (prune None logprobs from vllm data step)
        top_k_vals = [
            len(logprobs[i])
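To illustrate the new branch, here is a hedged sketch of the two per-sample formats it accepts. The entry structure for the raw format ("logprob" values plus "token_id:###" strings) and the `target_logprobs` / `target_token_ids` field names come from this diff; the top-level field name for the raw format is whatever `self.logprobs_field` is configured to, and the concrete numbers are invented.

```python
# Format A: raw top-k logprob entries (parsed later in this method).
sample_a = {
    "input_ids": [1, 2, 3],
    "logprobs": [  # assumes logprobs_field == "logprobs"
        [{"logprob": -0.11, "token": "token_id:1234"},
         {"logprob": -2.53, "token": "token_id:42"}],
    ],
}

# Format B: already-split teacher outputs, used as-is when both keys are present.
sample_b = {
    "input_ids": [1, 2, 3],
    "target_logprobs": [[-0.11, -2.53]],
    "target_token_ids": [[1234, 42]],
}
```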
@@ -82,11 +88,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
        target_token_ids = []
        target_mask = []

-       if input_padding_len < 0:
+       if target_padding_len < 0:
            # logprobs is longer than target_seq_len,
            # so we need to slice from the left/beginning of logprobs
            logprobs = logprobs[:-input_seq_len]
-           input_padding_len = 0
+           target_padding_len = 0
            # target_seq_len = input_seq_len

        # truncate the second dimension of the logprobs to top_k
@@ -98,33 +104,37 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
        # for causal models, if we start the range at 1, then we don't need to shift in the trainer
        # otherwise, we need to shift in the trainer
        shift = 0
-       for _ in range(shift, input_padding_len):
+       for _ in range(shift, target_padding_len):
            target_logprobs.append([-float("inf")] * top_k)
            target_token_ids.append(list(range(top_k)))
            target_mask.append([0] * top_k)

-       for position in range(input_padding_len, input_seq_len):
+       for position in range(target_padding_len, input_seq_len):
            if sample["labels"][position] == -100:
                target_mask.append([0] * top_k)
            else:
                target_mask.append([1] * top_k)

-       for _, token_pos_logprobs in enumerate(logprobs):
+       for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids):
            # Initialize collections for logprobs and token_ids
            position_logprobs = []
            position_token_ids = []

-           # Process each token probability entry
-           for entry in token_pos_logprobs:
-               # Extract logprob value
-               logprob = entry["logprob"]
+           if token_pos_token_ids is None:
+               for entry in token_pos_logprobs:
+                   # Extract logprob value
+                   logprob = entry["logprob"]

-               # Parse token_id from the "token_id:###" format
-               token_id = int(entry["token"].split(":")[1])
+                   # Parse token_id from the "token_id:###" format
+                   token_id = int(entry["token"].split(":")[1])

-               # Append to our collections
-               position_logprobs.append(logprob)
-               position_token_ids.append(token_id)
+                   # Append to our collections
+                   position_logprobs.append(logprob)
+                   position_token_ids.append(token_id)
+           else:
+               position_logprobs = token_pos_logprobs
+               position_token_ids = token_pos_token_ids

            # Convert to a tensor for easier manipulation
            position_logprobs_tensor = torch.tensor(
@@ -143,6 +153,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
            teacher_probs_t2 = teacher_probs_t1**exponent
        else:
            teacher_probs_t2 = teacher_probs_t1

        # Re-normalize
        teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
            dim=0, keepdim=True
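For intuition, the snippet below is a standalone sketch (not from the PR) of the temperature adjustment being applied above: the teacher's top-k probabilities are raised to an exponent and then renormalized so the truncated distribution sums to one again. Variable names follow the diff; the probabilities and exponent are made up.

```python
import torch

# Hypothetical top-k teacher probabilities for one token position.
teacher_probs_t1 = torch.tensor([0.70, 0.20, 0.10])

# Sharpen or flatten with an exponent; exponent == 1.0 leaves the distribution unchanged.
exponent = 0.5
teacher_probs_t2 = teacher_probs_t1**exponent

# Re-normalize so the truncated top-k distribution sums to 1 again.
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(dim=0, keepdim=True)
print(teacher_probs_t2)  # roughly tensor([0.5229, 0.2795, 0.1976])
```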
@@ -16,17 +16,35 @@
KD trainer
"""

+from transformers import TrainerControl
+
from axolotl.core.trainers.base import AxolotlTrainer

from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore


+class AxolotlKDTrainerControl(TrainerControl):
+    kd_alpha: float = 1.0
+    kd_ce_alpha: float = 0.0
+
+    def state(self) -> dict:
+        state_val = super().state()
+        state_val["args"]["kd_alpha"] = self.kd_alpha
+        state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha
+
+
class AxolotlKDTrainer(AxolotlTrainer):
    """
    Custom trainer subclass for Knowledge Distillation (KD)
    """

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.kd_alpha = self.args.kd_alpha
+        self.kd_ce_alpha = self.args.kd_ce_alpha
+        self.control = AxolotlKDTrainerControl()
+
    def _set_signature_columns_if_needed(self):
        super()._set_signature_columns_if_needed()
        columns_to_add = []
@@ -95,9 +113,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
            top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
        )

-       if self.args.kd_ce_alpha > 0:
-           kd_alpha = self.args.kd_alpha
-           loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
+       if self.kd_ce_alpha > 0:
+           loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd
        else:
            loss = loss_kd
        # Save past state if it exists
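As a quick numeric check of the blended loss above, the student is trained on a weighted sum of its own cross-entropy and the top-k distillation term; the coefficients and loss values below are hypothetical.

```python
# Hypothetical per-step values.
ce_loss = 2.0   # outputs["loss"], the ordinary next-token cross-entropy
kd_loss = 1.5   # loss_kd, the top-k forward-KL distillation term
kd_ce_alpha = 0.2
kd_alpha = 0.8

loss = kd_ce_alpha * ce_loss + kd_alpha * kd_loss if kd_ce_alpha > 0 else kd_loss
print(loss)  # 0.2 * 2.0 + 0.8 * 1.5 = 1.6
```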
@@ -813,6 +813,15 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
            )
        except (FileNotFoundError, ConnectionError) as err:
            LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
+       # TODO if using deepspeed and it's a file, save deepspeed config too
+       if args.deepspeed and os.path.isfile(args.deepspeed):
+           LOG.info(f"DeepSpeed config has been saved to the WandB run.")
+           artifact = wandb.Artifact(
+               f"deepspeed-{wandb.run.id}", type="deepspeed-config"
+           )
+           artifact.add_file(args.deepspeed)
+           wandb.log_artifact(artifact)
+           wandb.save(args.deepspeed)
        return control
@@ -173,10 +173,16 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                ]
                out_features[i][feature] = np.concatenate(arrays)
            else:
-               arrays = [
-                   np.array(item[feature]) for item in features_ if feature in item
-               ]
-               out_features[i][feature] = np.concatenate(arrays)
+               try:
+                   arrays = [
+                       np.array(item[feature])
+                       for item in features_
+                       if feature in item
+                   ]
+                   if arrays[0].dtype != "object":
+                       out_features[i][feature] = np.concatenate(arrays)
+               except ValueError:
+                   pass
        return super().__call__(out_features, return_tensors=return_tensors)
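The added try/except and dtype check guard against ragged per-token features (for example, variable-length top-k logprob lists), which NumPy cannot stack into a regular array. A small sketch of the failure mode being tolerated; the exact behavior depends on the NumPy version, so treat this as an assumption rather than a statement about the PR's environment.

```python
import numpy as np

ragged = [[-0.1, -2.5], [-0.3]]  # per-position top-k lists of unequal length
try:
    arr = np.array(ragged)       # NumPy >= 1.24 raises ValueError for inhomogeneous shapes
    print(arr.dtype)             # older NumPy instead falls back to dtype=object
except ValueError:
    pass                         # mirror the collator: skip concatenation for this feature
```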
@@ -25,8 +25,8 @@ def fixture_cfg():
            "optimizer": "adamw_torch_fused",
            "sequence_len": 2048,
            "rl": True,
-           "adam_beta1": 0.998,
-           "adam_beta2": 0.9,
+           "adam_beta1": 0.91,
+           "adam_beta2": 0.998,
            "adam_epsilon": 0.00001,
            "dataloader_num_workers": 1,
            "dataloader_pin_memory": True,
@@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder:
    def test_build_training_arguments(self, cfg, model, tokenizer):
        builder = HFRLTrainerBuilder(cfg, model, tokenizer)
        training_arguments = builder.build_training_arguments(100)
-       assert training_arguments.adam_beta1 == 0.998
-       assert training_arguments.adam_beta2 == 0.9
+       assert training_arguments.adam_beta1 == 0.91
+       assert training_arguments.adam_beta2 == 0.998
        assert training_arguments.adam_epsilon == 0.00001
        assert training_arguments.dataloader_num_workers == 1
        assert training_arguments.dataloader_pin_memory is True