Compare commits

...

1 Commit

Author      SHA1         Message                        Date
Wing Lian   8fc4c420a4   Add kd coefficient scheduler   2025-03-18 09:01:58 -04:00
9 changed files with 112 additions and 26 deletions

View File

@@ -751,8 +751,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.kd_ce_alpha is not None:
             training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
+        if self.cfg.kd_ce_alpha_end is not None:
+            training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end
         if self.cfg.kd_alpha is not None:
             training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
+        if self.cfg.kd_alpha_end is not None:
+            training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end
         if self.cfg.kd_temperature is not None:
             training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
         if self.cfg.kd_zscore_base_temp is not None:

View File

@@ -34,3 +34,12 @@ class KDPlugin(BasePlugin):
             return AxolotlKDTrainer
         return None
+
+    def add_callbacks_post_trainer(self, cfg, trainer):
+        callbacks = []
+        if cfg.kd_trainer:
+            from .callbacks import KDAlphaSchedulerCallback
+
+            callbacks.append(KDAlphaSchedulerCallback())
+        return callbacks

View File

@@ -30,6 +30,8 @@ class KDArgs(BaseModel):
         float
     ] = None  # loss coefficient for cross-entropy loss during KD
     kd_alpha: Optional[float] = None  # loss coefficient for KD loss
+    kd_ce_alpha_end: Optional[float] = None  # end value for kd_ce_alpha
+    kd_alpha_end: Optional[float] = None  # end value for kd_alpha
     kd_temperature: Optional[float] = None  # temperature for sampling during KD
     kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
     kd_top_k_before_softmax: Optional[
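
These pydantic fields map onto top-level keys of the Axolotl config that the trainer builder reads above. A minimal sketch of how the scheduled coefficients might be set, with illustrative values only (none of the numbers below come from this change):

# Illustrative config values only; kd_trainer enables the KD plugin's trainer.
cfg = {
    "kd_trainer": True,
    "kd_ce_alpha": 0.2,      # cross-entropy weight at the start of training
    "kd_ce_alpha_end": 0.0,  # cross-entropy weight reached at the final epoch
    "kd_alpha": 0.8,         # KD loss weight at the start of training
    "kd_alpha_end": 1.0,     # KD loss weight reached at the final epoch
    "kd_temperature": 1.0,
}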

View File

@@ -0,0 +1,28 @@
from transformers import TrainerCallback
class KDAlphaSchedulerCallback(TrainerCallback):
"""Callback to for scheduling KD alpha during training."""
def on_epoch_begin(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if int(state.epoch) == 0:
state.kd_alpha = args.kd_alpha
state.kd_ce_alpha = args.kd_ce_alpha
elif int(state.epoch) == state.num_train_epochs - 1:
if args.kd_alpha_end is not None:
control.kd_alpha = args.kd_alpha_end
if args.kd_ce_alpha_end is not None:
control.kd_ce_alpha = args.kd_ce_alpha_end
else:
epoch_steps = state.num_train_epochs - 1
scale = int(state.epoch) / epoch_steps
if args.kd_alpha_end is not None:
control.kd_alpha = (
args.kd_alpha + (args.kd_alpha_end - args.kd_alpha) * scale
)
if args.kd_ce_alpha_end is not None:
control.kd_ce_alpha = (
args.kd_ce_alpha + (args.kd_ce_alpha_end - args.kd_ce_alpha) * scale
)
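
The callback pins each coefficient to its start value on epoch 0, to its *_end value on the last epoch, and interpolates linearly in between. A minimal standalone sketch of that schedule, with made-up values (start=0.9, end=0.5 over 4 epochs are illustrative, not from this change):

def scheduled_alpha(start: float, end: float, epoch: int, num_epochs: int) -> float:
    # Same piecewise-linear rule as on_epoch_begin above.
    if epoch == 0:
        return start
    if epoch >= num_epochs - 1:
        return end
    scale = epoch / (num_epochs - 1)
    return start + (end - start) * scale

for epoch in range(4):
    print(epoch, round(scheduled_alpha(0.9, 0.5, epoch, 4), 3))
# 0 0.9
# 1 0.767
# 2 0.633
# 3 0.5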

View File

@@ -62,10 +62,16 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         Transform logprobs to target format for KD training
         """
-        logprobs = sample.pop(self.logprobs_field)
+        if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys():
+            logprobs = sample.pop("target_logprobs")
+            token_ids = sample.pop("target_token_ids")
+        else:
+            logprobs = sample.pop(self.logprobs_field)
+            token_ids = [None] * len(logprobs)
         target_seq_len = len(logprobs)
         input_seq_len = len(sample["input_ids"])
-        input_padding_len = input_seq_len - target_seq_len
+        target_padding_len = input_seq_len - target_seq_len
         # get non-zero top-k (prune None logprobs from vllm data step)
         top_k_vals = [
             len(logprobs[i])
@@ -82,11 +88,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         target_token_ids = []
         target_mask = []
-        if input_padding_len < 0:
+        if target_padding_len < 0:
             # logprobs is longer than target_seq_len,
             # so we need to slice from the left/beginning of logprobs
             logprobs = logprobs[:-input_seq_len]
-            input_padding_len = 0
+            target_padding_len = 0
             # target_seq_len = input_seq_len
         # truncate the second dimension of the logprobs to top_k
@@ -98,33 +104,37 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         # for causal models, if we start the range at 1, then we don't need to shift in the trainer
         # otherwise, we need to shift in the trainer
         shift = 0
-        for _ in range(shift, input_padding_len):
+        for _ in range(shift, target_padding_len):
             target_logprobs.append([-float("inf")] * top_k)
             target_token_ids.append(list(range(top_k)))
             target_mask.append([0] * top_k)
-        for position in range(input_padding_len, input_seq_len):
+        for position in range(target_padding_len, input_seq_len):
             if sample["labels"][position] == -100:
                 target_mask.append([0] * top_k)
             else:
                 target_mask.append([1] * top_k)
-        for _, token_pos_logprobs in enumerate(logprobs):
+        for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids):
             # Initialize collections for logprobs and token_ids
             position_logprobs = []
             position_token_ids = []
             # Process each token probability entry
-            for entry in token_pos_logprobs:
-                # Extract logprob value
-                logprob = entry["logprob"]
-                # Parse token_id from the "token_id:###" format
-                token_id = int(entry["token"].split(":")[1])
-                # Append to our collections
-                position_logprobs.append(logprob)
-                position_token_ids.append(token_id)
+            if token_pos_token_ids is None:
+                for entry in token_pos_logprobs:
+                    # Extract logprob value
+                    logprob = entry["logprob"]
+                    # Parse token_id from the "token_id:###" format
+                    token_id = int(entry["token"].split(":")[1])
+                    # Append to our collections
+                    position_logprobs.append(logprob)
+                    position_token_ids.append(token_id)
+            else:
+                position_logprobs = token_pos_logprobs
+                position_token_ids = token_pos_token_ids
             # Convert to a tensor for easier manipulation
             position_logprobs_tensor = torch.tensor(
@@ -143,6 +153,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
                 teacher_probs_t2 = teacher_probs_t1**exponent
             else:
                 teacher_probs_t2 = teacher_probs_t1
+            # Re-normalize
             teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
                 dim=0, keepdim=True
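
With this change the transform accepts two dataset layouts: pre-tokenized parallel lists under target_logprobs / target_token_ids, or the original per-position entries under the configured logprobs field (self.logprobs_field), where each top-k entry encodes its token id as "token_id:###". A rough sketch of the two shapes, with made-up values and an assumed field name:

# Illustrative samples only; numbers, ids, and the "logprobs" key name are assumptions.
sample_pretokenized = {
    "target_logprobs": [[-0.1, -2.3, -4.0], [-0.2, -1.9, -3.5]],  # top-k per position
    "target_token_ids": [[42, 7, 99], [13, 42, 5]],               # matching token ids
}

sample_vllm_style = {
    "logprobs": [  # key is whatever self.logprobs_field is configured to
        [{"logprob": -0.1, "token": "token_id:42"}, {"logprob": -2.3, "token": "token_id:7"}],
        [{"logprob": -0.2, "token": "token_id:13"}, {"logprob": -1.9, "token": "token_id:42"}],
    ],
}

# vllm-style entries are parsed position by position, as in the loop above:
entry = {"logprob": -0.1, "token": "token_id:42"}
token_id = int(entry["token"].split(":")[1])  # 42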

View File

@@ -16,17 +16,35 @@
 KD trainer
 """
+from transformers import TrainerControl
+
 from axolotl.core.trainers.base import AxolotlTrainer
 
 from .topk_logprob.forward_kl import loss as topk_kd_loss
 from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
 
+
+class AxolotlKDTrainerControl(TrainerControl):
+    kd_alpha: float = 1.0
+    kd_ce_alpha: float = 0.0
+
+    def state(self) -> dict:
+        state_val = super().state()
+        state_val["args"]["kd_alpha"] = self.kd_alpha
+        state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha
+
+
 class AxolotlKDTrainer(AxolotlTrainer):
     """
     Custom trainer subclass for Knowledge Distillation (KD)
     """
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.kd_alpha = self.args.kd_alpha
+        self.kd_ce_alpha = self.args.kd_ce_alpha
+        self.control = AxolotlKDTrainerControl()
+
     def _set_signature_columns_if_needed(self):
         super()._set_signature_columns_if_needed()
         columns_to_add = []
@@ -95,9 +113,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
                 top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
             )
-        if self.args.kd_ce_alpha > 0:
-            kd_alpha = self.args.kd_alpha
-            loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
+        if self.kd_ce_alpha > 0:
+            loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd
         else:
             loss = loss_kd
 
         # Save past state if it exists
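
The loss now blends the student's cross-entropy loss with the distillation loss using the instance attributes (which the scheduler callback can update per epoch) rather than the static training args. A minimal numeric sketch of the blend, with placeholder values rather than real model outputs:

import torch

kd_ce_alpha, kd_alpha = 0.2, 0.8     # illustrative coefficients
ce_loss = torch.tensor(2.1)          # stands in for outputs["loss"]
loss_kd = torch.tensor(1.4)          # stands in for the top-k forward-KL loss

loss = kd_ce_alpha * ce_loss + kd_alpha * loss_kd if kd_ce_alpha > 0 else loss_kd
# 0.2 * 2.1 + 0.8 * 1.4 = 1.54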

View File

@@ -813,6 +813,15 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
             )
         except (FileNotFoundError, ConnectionError) as err:
             LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
+        # TODO if using deepspeed and it's a file, save deepspeed config too
+        if args.deepspeed and os.path.isfile(args.deepspeed):
+            LOG.info(f"DeepSpeed config has been saved to the WandB run.")
+            artifact = wandb.Artifact(
+                f"deepspeed-{wandb.run.id}", type="deepspeed-config"
+            )
+            artifact.add_file(args.deepspeed)
+            wandb.log_artifact(artifact)
+            wandb.save(args.deepspeed)
         return control

View File

@@ -173,10 +173,16 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                     ]
                     out_features[i][feature] = np.concatenate(arrays)
                 else:
-                    arrays = [
-                        np.array(item[feature]) for item in features_ if feature in item
-                    ]
-                    out_features[i][feature] = np.concatenate(arrays)
+                    try:
+                        arrays = [
+                            np.array(item[feature])
+                            for item in features_
+                            if feature in item
+                        ]
+                        if arrays[0].dtype != "object":
+                            out_features[i][feature] = np.concatenate(arrays)
+                    except ValueError:
+                        pass
 
         return super().__call__(out_features, return_tensors=return_tensors)
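
The guard exists because per-sample KD features such as the top-k logprob lists can be ragged (different lengths per example). A small illustration, assuming only standard NumPy behavior: NumPy >= 1.24 raises ValueError when building an array from ragged lists (caught by the new except), while older versions return an object-dtype array (skipped by the new dtype check).

import numpy as np

ragged = [[-0.1, -2.3, -4.0], [-0.2, -1.9]]  # top-k lists of unequal length
try:
    arr = np.array(ragged)
    if arr.dtype != "object":          # older NumPy: object array, so skip concatenation
        print(np.concatenate([arr]))
except ValueError as err:              # NumPy >= 1.24 refuses ragged construction
    print(f"skipping ragged feature: {err}")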

View File

@@ -25,8 +25,8 @@ def fixture_cfg():
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"sequence_len": 2048, "sequence_len": 2048,
"rl": True, "rl": True,
"adam_beta1": 0.998, "adam_beta1": 0.91,
"adam_beta2": 0.9, "adam_beta2": 0.998,
"adam_epsilon": 0.00001, "adam_epsilon": 0.00001,
"dataloader_num_workers": 1, "dataloader_num_workers": 1,
"dataloader_pin_memory": True, "dataloader_pin_memory": True,
@@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder:
     def test_build_training_arguments(self, cfg, model, tokenizer):
         builder = HFRLTrainerBuilder(cfg, model, tokenizer)
         training_arguments = builder.build_training_arguments(100)
-        assert training_arguments.adam_beta1 == 0.998
-        assert training_arguments.adam_beta2 == 0.9
+        assert training_arguments.adam_beta1 == 0.91
+        assert training_arguments.adam_beta2 == 0.998
         assert training_arguments.adam_epsilon == 0.00001
         assert training_arguments.dataloader_num_workers == 1
         assert training_arguments.dataloader_pin_memory is True