From 8fc4c420a418aa31c2df10ad7eaebb60e1f53650 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 18 Mar 2025 09:01:58 -0400 Subject: [PATCH] Add kd coefficient scheduler --- src/axolotl/core/trainer_builder.py | 4 ++ src/axolotl/integrations/kd/__init__.py | 9 +++++ src/axolotl/integrations/kd/args.py | 2 + src/axolotl/integrations/kd/callbacks.py | 28 +++++++++++++ src/axolotl/integrations/kd/chat_template.py | 41 +++++++++++++------- src/axolotl/integrations/kd/trainer.py | 23 +++++++++-- src/axolotl/utils/callbacks/__init__.py | 9 +++++ src/axolotl/utils/collators/batching.py | 14 +++++-- tests/core/test_trainer_builder.py | 8 ++-- 9 files changed, 112 insertions(+), 26 deletions(-) create mode 100644 src/axolotl/integrations/kd/callbacks.py diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 0c9204747..aeeb8b270 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -751,8 +751,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.kd_ce_alpha is not None: training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha + if self.cfg.kd_ce_alpha_end is not None: + training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end if self.cfg.kd_alpha is not None: training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha + if self.cfg.kd_alpha_end is not None: + training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end if self.cfg.kd_temperature is not None: training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature if self.cfg.kd_zscore_base_temp is not None: diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py index 8a6e3eda1..09b6adace 100644 --- a/src/axolotl/integrations/kd/__init__.py +++ b/src/axolotl/integrations/kd/__init__.py @@ -34,3 +34,12 @@ class KDPlugin(BasePlugin): return AxolotlKDTrainer return None + + def add_callbacks_post_trainer(self, cfg, trainer): + callbacks = [] + if cfg.kd_trainer: + from .callbacks import KDAlphaSchedulerCallback + + callbacks.append(KDAlphaSchedulerCallback()) + + return callbacks diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py index a88a0dc48..57fe43f9e 100644 --- a/src/axolotl/integrations/kd/args.py +++ b/src/axolotl/integrations/kd/args.py @@ -30,6 +30,8 @@ class KDArgs(BaseModel): float ] = None # loss coefficient for cross-entropy loss during KD kd_alpha: Optional[float] = None # loss coefficient for KD loss + kd_ce_alpha_end: Optional[float] = None # end value for kd_ce_alpha + kd_alpha_end: Optional[float] = None # end value for kd_alpha kd_temperature: Optional[float] = None # temperature for sampling during KD kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling kd_top_k_before_softmax: Optional[ diff --git a/src/axolotl/integrations/kd/callbacks.py b/src/axolotl/integrations/kd/callbacks.py new file mode 100644 index 000000000..81b82deef --- /dev/null +++ b/src/axolotl/integrations/kd/callbacks.py @@ -0,0 +1,28 @@ +from transformers import TrainerCallback + + +class KDAlphaSchedulerCallback(TrainerCallback): + """Callback to for scheduling KD alpha during training.""" + + def on_epoch_begin( + self, args, state, control, **kwargs # pylint: disable=unused-argument + ): + if int(state.epoch) == 0: + state.kd_alpha = args.kd_alpha + state.kd_ce_alpha = args.kd_ce_alpha + elif int(state.epoch) == state.num_train_epochs - 1: + if args.kd_alpha_end is not None: + control.kd_alpha = args.kd_alpha_end + if args.kd_ce_alpha_end is not None: + control.kd_ce_alpha = args.kd_ce_alpha_end + else: + epoch_steps = state.num_train_epochs - 1 + scale = int(state.epoch) / epoch_steps + if args.kd_alpha_end is not None: + control.kd_alpha = ( + args.kd_alpha + (args.kd_alpha_end - args.kd_alpha) * scale + ) + if args.kd_ce_alpha_end is not None: + control.kd_ce_alpha = ( + args.kd_ce_alpha + (args.kd_ce_alpha_end - args.kd_ce_alpha) * scale + ) diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index 699728e9f..7efe97e63 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -62,10 +62,16 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): Transform logprobs to target format for KD training """ - logprobs = sample.pop(self.logprobs_field) + if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys(): + logprobs = sample.pop("target_logprobs") + token_ids = sample.pop("target_token_ids") + else: + logprobs = sample.pop(self.logprobs_field) + token_ids = [None] * len(logprobs) + target_seq_len = len(logprobs) input_seq_len = len(sample["input_ids"]) - input_padding_len = input_seq_len - target_seq_len + target_padding_len = input_seq_len - target_seq_len # get non-zero top-k (prune None logprobs from vllm data step) top_k_vals = [ len(logprobs[i]) @@ -82,11 +88,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_token_ids = [] target_mask = [] - if input_padding_len < 0: + if target_padding_len < 0: # logprobs is longer than target_seq_len, # so we need to slice from the left/beginning of logprobs logprobs = logprobs[:-input_seq_len] - input_padding_len = 0 + target_padding_len = 0 # target_seq_len = input_seq_len # truncate the second dimension of the logprobs to top_k @@ -98,33 +104,37 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): # for causal models, if we start the range at 1, then we don't need to shift in the trainer # otherwise, we need to shift in the trainer shift = 0 - for _ in range(shift, input_padding_len): + for _ in range(shift, target_padding_len): target_logprobs.append([-float("inf")] * top_k) target_token_ids.append(list(range(top_k))) target_mask.append([0] * top_k) - for position in range(input_padding_len, input_seq_len): + for position in range(target_padding_len, input_seq_len): if sample["labels"][position] == -100: target_mask.append([0] * top_k) else: target_mask.append([1] * top_k) - for _, token_pos_logprobs in enumerate(logprobs): + for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids): # Initialize collections for logprobs and token_ids position_logprobs = [] position_token_ids = [] # Process each token probability entry - for entry in token_pos_logprobs: - # Extract logprob value - logprob = entry["logprob"] + if token_pos_token_ids is None: + for entry in token_pos_logprobs: + # Extract logprob value + logprob = entry["logprob"] - # Parse token_id from the "token_id:###" format - token_id = int(entry["token"].split(":")[1]) + # Parse token_id from the "token_id:###" format + token_id = int(entry["token"].split(":")[1]) - # Append to our collections - position_logprobs.append(logprob) - position_token_ids.append(token_id) + # Append to our collections + position_logprobs.append(logprob) + position_token_ids.append(token_id) + else: + position_logprobs = token_pos_logprobs + position_token_ids = token_pos_token_ids # Convert to a tensor for easier manipulation position_logprobs_tensor = torch.tensor( @@ -143,6 +153,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): teacher_probs_t2 = teacher_probs_t1**exponent else: teacher_probs_t2 = teacher_probs_t1 + # Re-normalize teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( dim=0, keepdim=True diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index f99f2ca28..c68e8907c 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -16,17 +16,35 @@ KD trainer """ +from transformers import TrainerControl + from axolotl.core.trainers.base import AxolotlTrainer from .topk_logprob.forward_kl import loss as topk_kd_loss from .topk_logprob.forward_kl import topk_kd_loss_with_zscore +class AxolotlKDTrainerControl(TrainerControl): + kd_alpha: float = 1.0 + kd_ce_alpha: float = 0.0 + + def state(self) -> dict: + state_val = super().state() + state_val["args"]["kd_alpha"] = self.kd_alpha + state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha + + class AxolotlKDTrainer(AxolotlTrainer): """ Custom trainer subclass for Knowledge Distillation (KD) """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.kd_alpha = self.args.kd_alpha + self.kd_ce_alpha = self.args.kd_ce_alpha + self.control = AxolotlKDTrainerControl() + def _set_signature_columns_if_needed(self): super()._set_signature_columns_if_needed() columns_to_add = [] @@ -95,9 +113,8 @@ class AxolotlKDTrainer(AxolotlTrainer): top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0, ) - if self.args.kd_ce_alpha > 0: - kd_alpha = self.args.kd_alpha - loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd + if self.kd_ce_alpha > 0: + loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd else: loss = loss_kd # Save past state if it exists diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 9ca0e84fe..d487231a7 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -813,6 +813,15 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback): ) except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to WandB: {err}") + # TODO if using deepspeed and it's a file, save deepspeed config too + if args.deepspeed and os.path.isfile(args.deepspeed): + LOG.info(f"DeepSpeed config has been saved to the WandB run.") + artifact = wandb.Artifact( + f"deepspeed-{wandb.run.id}", type="deepspeed-config" + ) + artifact.add_file(args.deepspeed) + wandb.log_artifact(artifact) + wandb.save(args.deepspeed) return control diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 7cf771421..c1f7809c0 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -173,10 +173,16 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): ] out_features[i][feature] = np.concatenate(arrays) else: - arrays = [ - np.array(item[feature]) for item in features_ if feature in item - ] - out_features[i][feature] = np.concatenate(arrays) + try: + arrays = [ + np.array(item[feature]) + for item in features_ + if feature in item + ] + if arrays[0].dtype != "object": + out_features[i][feature] = np.concatenate(arrays) + except ValueError: + pass return super().__call__(out_features, return_tensors=return_tensors) diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py index fbfd7a87c..508ff01a0 100644 --- a/tests/core/test_trainer_builder.py +++ b/tests/core/test_trainer_builder.py @@ -25,8 +25,8 @@ def fixture_cfg(): "optimizer": "adamw_torch_fused", "sequence_len": 2048, "rl": True, - "adam_beta1": 0.998, - "adam_beta2": 0.9, + "adam_beta1": 0.91, + "adam_beta2": 0.998, "adam_epsilon": 0.00001, "dataloader_num_workers": 1, "dataloader_pin_memory": True, @@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder: def test_build_training_arguments(self, cfg, model, tokenizer): builder = HFRLTrainerBuilder(cfg, model, tokenizer) training_arguments = builder.build_training_arguments(100) - assert training_arguments.adam_beta1 == 0.998 - assert training_arguments.adam_beta2 == 0.9 + assert training_arguments.adam_beta1 == 0.91 + assert training_arguments.adam_beta2 == 0.998 assert training_arguments.adam_epsilon == 0.00001 assert training_arguments.dataloader_num_workers == 1 assert training_arguments.dataloader_pin_memory is True