Compare commits
36 Commits
mistral-su
...
kd-fix-202
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2491303c46 | ||
|
|
2c66483a47 | ||
|
|
01382b9a79 | ||
|
|
cfcd69df0d | ||
|
|
2302b14a84 | ||
|
|
a8e2bddd19 | ||
|
|
d55a51623f | ||
|
|
73a84ad0dd | ||
|
|
3cffe881bb | ||
|
|
e77d62933d | ||
|
|
3a0faa97ca | ||
|
|
20602fd93f | ||
|
|
770bb0605a | ||
|
|
24b96b1c4f | ||
|
|
90c7228ff9 | ||
|
|
9eb53f5c9e | ||
|
|
225b420dc5 | ||
|
|
b75db13615 | ||
|
|
c7b1db329e | ||
|
|
a40e484803 | ||
|
|
9899c924f9 | ||
|
|
505009b454 | ||
|
|
b4e96ef12c | ||
|
|
a8d9fab635 | ||
|
|
49e2fa825d | ||
|
|
7263845207 | ||
|
|
5ccfd225cb | ||
|
|
28eb8632a1 | ||
|
|
5cfaac3767 | ||
|
|
ca70fb7cb0 | ||
|
|
22b50d6619 | ||
|
|
a2248673d8 | ||
|
|
0399aefcb3 | ||
|
|
83ad248e5b | ||
|
|
6fafe46562 | ||
|
|
0e46367e01 |
31
deepspeed_configs/zero2_torch_compile.json
Normal file
31
deepspeed_configs/zero2_torch_compile.json
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"compile": {
|
||||||
|
"disable": false,
|
||||||
|
"backend": "inductor"
|
||||||
|
},
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 2,
|
||||||
|
"offload_optimizer": {
|
||||||
|
"device": "cpu"
|
||||||
|
},
|
||||||
|
"contiguous_gradients": true,
|
||||||
|
"overlap_comm": true
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": "auto"
|
||||||
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 32,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"gradient_clipping": "auto",
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"wall_clock_breakdown": false
|
||||||
|
}
|
||||||
@@ -20,7 +20,6 @@ datasets==3.6.0
|
|||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.18.1
|
trl==0.18.1
|
||||||
hf_xet==1.1.2
|
hf_xet==1.1.2
|
||||||
mistral-common[hf-hub]==1.6.0
|
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
|
|||||||
@@ -305,8 +305,8 @@ def load_model_and_tokenizer(
|
|||||||
ProcessorMixin | None,
|
ProcessorMixin | None,
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
Helper function for loading a model, tokenizer, and processor specified in the
|
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
||||||
given `axolotl` config.
|
config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
|||||||
@@ -21,11 +21,6 @@ from axolotl.core.trainers import (
|
|||||||
AxolotlTrainer,
|
AxolotlTrainer,
|
||||||
ReLoRATrainer,
|
ReLoRATrainer,
|
||||||
)
|
)
|
||||||
from axolotl.core.training_args import (
|
|
||||||
AxolotlPRMConfig,
|
|
||||||
AxolotlRewardConfig,
|
|
||||||
AxolotlTrainingArguments,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
@@ -130,6 +125,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def _get_trainer_cls(self):
|
def _get_trainer_cls(self):
|
||||||
|
"""
|
||||||
|
Gets the trainer class for the given configuration.
|
||||||
|
"""
|
||||||
if self.cfg.plugins:
|
if self.cfg.plugins:
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||||
@@ -146,6 +144,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
|
from axolotl.core.training_args import (
|
||||||
|
AxolotlPRMConfig,
|
||||||
|
AxolotlRewardConfig,
|
||||||
|
AxolotlTrainingArguments,
|
||||||
|
)
|
||||||
|
|
||||||
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
|
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||||
total_num_steps
|
total_num_steps
|
||||||
)
|
)
|
||||||
@@ -314,20 +318,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["image_resize_algorithm"] = (
|
training_arguments_kwargs["image_resize_algorithm"] = (
|
||||||
self.cfg.image_resize_algorithm
|
self.cfg.image_resize_algorithm
|
||||||
)
|
)
|
||||||
if self.cfg.kd_ce_alpha is not None:
|
|
||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
if self.cfg.plugins:
|
||||||
if self.cfg.kd_alpha is not None:
|
plugin_manager = PluginManager.get_instance()
|
||||||
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
plugin_training_args = plugin_manager.get_training_args(self.cfg)
|
||||||
if self.cfg.kd_temperature is not None:
|
if plugin_training_args:
|
||||||
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
training_arguments_kwargs.update(plugin_training_args)
|
||||||
if self.cfg.kd_zscore_base_temp is not None:
|
|
||||||
training_arguments_kwargs["kd_zscore_base_temp"] = (
|
|
||||||
self.cfg.kd_zscore_base_temp
|
|
||||||
)
|
|
||||||
if self.cfg.kd_top_k_before_softmax is not None:
|
|
||||||
training_arguments_kwargs["kd_top_k_before_softmax"] = (
|
|
||||||
self.cfg.kd_top_k_before_softmax
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
@@ -408,7 +404,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
def build_collator(
|
def build_collator(
|
||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self,
|
||||||
|
training_args, # type: "AxolotlTrainingArguments" # type: ignore
|
||||||
|
is_eval=False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if (
|
if (
|
||||||
@@ -437,7 +436,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
collator_args = [self.tokenizer]
|
collator_args = [self.tokenizer]
|
||||||
if self.cfg.reward_model:
|
|
||||||
|
if self.cfg.plugins:
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
|
||||||
|
self.cfg, is_eval=is_eval
|
||||||
|
)
|
||||||
|
|
||||||
|
if collator_cls_and_kwargs:
|
||||||
|
collator = collator_cls_and_kwargs[0]
|
||||||
|
if kwargs and isinstance(kwargs, dict):
|
||||||
|
kwargs.update(collator_cls_and_kwargs[1])
|
||||||
|
elif self.cfg.reward_model:
|
||||||
collator = RewardDataCollatorWithPadding
|
collator = RewardDataCollatorWithPadding
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
||||||
@@ -468,16 +478,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator_args.pop(0)
|
collator_args.pop(0)
|
||||||
kwargs.pop("pad_to_multiple_of", None)
|
kwargs.pop("pad_to_multiple_of", None)
|
||||||
kwargs.pop("padding", None)
|
kwargs.pop("padding", None)
|
||||||
elif self.cfg.kd_trainer:
|
|
||||||
from axolotl.integrations.kd.collator import (
|
|
||||||
DataCollatorForKD,
|
|
||||||
KDBatchSamplerDataCollatorForSeq2Seq,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.sample_packing:
|
|
||||||
collator = KDBatchSamplerDataCollatorForSeq2Seq
|
|
||||||
else:
|
|
||||||
collator = DataCollatorForKD
|
|
||||||
else:
|
else:
|
||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
|||||||
@@ -12,11 +12,6 @@ from axolotl.core.trainers import (
|
|||||||
from axolotl.core.trainers.dpo import DPOStrategy
|
from axolotl.core.trainers.dpo import DPOStrategy
|
||||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
from axolotl.core.training_args import (
|
|
||||||
AxolotlCPOConfig,
|
|
||||||
AxolotlKTOConfig,
|
|
||||||
AxolotlORPOConfig,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders.utils import ensure_dtype
|
from axolotl.loaders.utils import ensure_dtype
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -79,6 +74,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
"""
|
"""
|
||||||
Returns training_args and trainer_kwargs
|
Returns training_args and trainer_kwargs
|
||||||
"""
|
"""
|
||||||
|
from axolotl.core.training_args import (
|
||||||
|
AxolotlCPOConfig,
|
||||||
|
AxolotlKTOConfig,
|
||||||
|
AxolotlORPOConfig,
|
||||||
|
)
|
||||||
|
|
||||||
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
|
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||||
total_num_steps=total_num_steps
|
total_num_steps=total_num_steps
|
||||||
)
|
)
|
||||||
@@ -165,6 +166,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if blocklist_key in training_args_kwargs:
|
if blocklist_key in training_args_kwargs:
|
||||||
del training_args_kwargs[blocklist_key]
|
del training_args_kwargs[blocklist_key]
|
||||||
|
|
||||||
|
|
||||||
|
if self.cfg.plugins:
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
plugin_training_args = plugin_manager.get_training_args(self.cfg)
|
||||||
|
if plugin_training_args:
|
||||||
|
training_args_kwargs.update(plugin_training_args)
|
||||||
|
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
logging_first_step=True,
|
logging_first_step=True,
|
||||||
**training_args_kwargs,
|
**training_args_kwargs,
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from axolotl.core.trainers.utils import (
|
|||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils import get_not_null
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
@@ -101,7 +102,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
|||||||
)
|
)
|
||||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||||
|
|
||||||
return MultipackBatchSampler(
|
sampler = MultipackBatchSampler(
|
||||||
base_sampler,
|
base_sampler,
|
||||||
lengths=get_dataset_lengths(dataset),
|
lengths=get_dataset_lengths(dataset),
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
@@ -113,6 +114,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
|||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
len(sampler)
|
||||||
|
return sampler
|
||||||
|
|
||||||
def _get_train_sampler(
|
def _get_train_sampler(
|
||||||
self, train_dataset: Optional[Dataset] = None
|
self, train_dataset: Optional[Dataset] = None
|
||||||
) -> Optional[Sampler]:
|
) -> Optional[Sampler]:
|
||||||
@@ -220,7 +224,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = get_not_null(
|
||||||
|
self.args.dataloader_drop_last, True
|
||||||
|
)
|
||||||
if sampler_fn is not None:
|
if sampler_fn is not None:
|
||||||
sampler = sampler_fn(dataset)
|
sampler = sampler_fn(dataset)
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from functools import partial
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -59,42 +58,6 @@ class AxolotlGRPOTrainer(
|
|||||||
|
|
||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
def get_train_dataloader(self):
|
|
||||||
if self.train_dataset is None:
|
|
||||||
raise ValueError("Trainer: training requires a train_dataset.")
|
|
||||||
|
|
||||||
train_dataset = self.train_dataset
|
|
||||||
data_collator = self.data_collator
|
|
||||||
if isinstance(train_dataset, datasets.Dataset):
|
|
||||||
train_dataset = self._remove_unused_columns(
|
|
||||||
train_dataset, description="training"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
data_collator = self._get_collator_with_removed_columns(
|
|
||||||
data_collator, description="training"
|
|
||||||
)
|
|
||||||
|
|
||||||
dataloader_params = {
|
|
||||||
"batch_size": self._train_batch_size
|
|
||||||
* self.args.steps_per_generation, # < this is the change
|
|
||||||
"collate_fn": data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
|
||||||
"persistent_workers": self.args.dataloader_persistent_workers,
|
|
||||||
}
|
|
||||||
|
|
||||||
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
|
||||||
dataloader_params["sampler"] = self._get_train_sampler()
|
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
|
||||||
dataloader_params["worker_init_fn"] = partial(
|
|
||||||
seed_worker,
|
|
||||||
num_workers=self.args.dataloader_num_workers,
|
|
||||||
rank=self.args.process_index,
|
|
||||||
)
|
|
||||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
|
||||||
|
|
||||||
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||||
|
|||||||
@@ -2,238 +2,17 @@
|
|||||||
extra axolotl specific training args
|
extra axolotl specific training args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from __future__ import annotations
|
||||||
from typing import Optional
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
from PIL.Image import Resampling
|
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
|
from axolotl.integrations.config import merge_training_args
|
||||||
|
|
||||||
@dataclass
|
AxolotlTrainingMixins: Type = merge_training_args()
|
||||||
class AxolotlTrainingMixins:
|
|
||||||
"""
|
|
||||||
Mixin class for the Axolotl training args.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
model_type: Optional[str] = field(
|
|
||||||
default=None, metadata={"help": "HF model configuration model_type."}
|
|
||||||
)
|
|
||||||
lr_quadratic_warmup: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
|
||||||
)
|
|
||||||
pretraining: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={
|
|
||||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
sample_packing: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use sample packing for efficient training."},
|
|
||||||
)
|
|
||||||
sample_packing_sequentially: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={
|
|
||||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
multipack_real_batches: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
|
||||||
)
|
|
||||||
eval_sample_packing: Optional[bool] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Use sample packing for efficient evals."},
|
|
||||||
)
|
|
||||||
sample_packing_efficiency: float = field(
|
|
||||||
default=1.0,
|
|
||||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
|
||||||
)
|
|
||||||
sample_packing_bin_size: int = field(
|
|
||||||
default=200,
|
|
||||||
metadata={
|
|
||||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
sample_packing_group_size: int = field(
|
|
||||||
default=100000,
|
|
||||||
metadata={
|
|
||||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
max_seq_length: int = field(
|
|
||||||
default=2048,
|
|
||||||
metadata={"help": "The maximum sequence length the model can handle"},
|
|
||||||
)
|
|
||||||
relora_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how often to reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_warmup_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_anneal_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_prune_ratio: Optional[float] = field(
|
|
||||||
default=0.9,
|
|
||||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
|
||||||
)
|
|
||||||
bench_split: Optional[str] = field(
|
|
||||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
|
||||||
)
|
|
||||||
bench_dataset: Optional[str] = field(
|
|
||||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
|
||||||
metadata={
|
|
||||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
do_bench_eval: Optional[bool] = field(
|
|
||||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
|
||||||
)
|
|
||||||
do_causal_lm_eval: Optional[bool] = field(
|
|
||||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
|
||||||
)
|
|
||||||
max_bench_samples: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
bench_source_max_len: int = field(
|
|
||||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
|
||||||
)
|
|
||||||
dataloader_prefetch_factor: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
|
||||||
)
|
|
||||||
cosine_min_lr_ratio: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
|
||||||
)
|
|
||||||
cosine_constant_lr_ratio: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
loraplus_lr_ratio: Optional[float] = field(
|
|
||||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
|
||||||
)
|
|
||||||
loraplus_lr_embedding: Optional[float] = field(
|
|
||||||
default=1e-6,
|
|
||||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
|
||||||
)
|
|
||||||
embedding_lr_scale: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
|
||||||
)
|
|
||||||
lr_groups: Optional[list[dict]] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Specify learning rate groups for with different LRs."},
|
|
||||||
)
|
|
||||||
embedding_lr: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
|
||||||
)
|
|
||||||
qlora: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "whether this is a qlora training"},
|
|
||||||
)
|
|
||||||
orpo_alpha: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
lisa_n_layers: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "the number of activate layers in LISA"},
|
|
||||||
)
|
|
||||||
lisa_step_interval: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how often to switch layers in LISA"},
|
|
||||||
)
|
|
||||||
lisa_layers_attribute: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "path under the model to access the layers"},
|
|
||||||
)
|
|
||||||
curriculum_sampling: Optional[bool] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
|
||||||
)
|
|
||||||
alternate_lr_scheduler_type: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
chat_template: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Chat template converting chat messages to text"},
|
|
||||||
)
|
|
||||||
|
|
||||||
kd_ce_alpha: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
kd_alpha: Optional[float] = field(
|
|
||||||
default=1.0,
|
|
||||||
metadata={"help": "The alpha scaling parameter for KD loss"},
|
|
||||||
)
|
|
||||||
|
|
||||||
kd_temperature: Optional[float] = field(
|
|
||||||
default=1.0,
|
|
||||||
metadata={
|
|
||||||
"help": "the temperature parameter for KL divergence loss when using KD"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
kd_zscore_base_temp: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "the base temperature parameter for KL divergence with z-score when using KD"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
kd_top_k_before_softmax: Optional[bool] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
adam_beta3: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
adam_epsilon2: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# multi-modal section
|
|
||||||
|
|
||||||
image_size: int | tuple[int, int] | None = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The size of the image to resize to"},
|
|
||||||
)
|
|
||||||
|
|
||||||
image_resize_algorithm: Resampling | None = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The algorithm to use for image resizing"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# end of multi-modal section
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
220
src/axolotl/core/training_args_base.py
Normal file
220
src/axolotl/core/training_args_base.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
Base Axolotl Training Mixins shared across various trainer configs
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL.Image import Resampling
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlTrainingMixins:
|
||||||
|
"""
|
||||||
|
Mixin class for the Axolotl training args.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
model_type: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "HF model configuration model_type."}
|
||||||
|
)
|
||||||
|
lr_quadratic_warmup: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||||
|
)
|
||||||
|
pretraining: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sample_packing: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use sample packing for efficient training."},
|
||||||
|
)
|
||||||
|
sample_packing_sequentially: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
multipack_real_batches: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
|
)
|
||||||
|
eval_sample_packing: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Use sample packing for efficient evals."},
|
||||||
|
)
|
||||||
|
sample_packing_efficiency: float = field(
|
||||||
|
default=1.0,
|
||||||
|
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||||
|
)
|
||||||
|
sample_packing_bin_size: int = field(
|
||||||
|
default=200,
|
||||||
|
metadata={
|
||||||
|
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sample_packing_group_size: int = field(
|
||||||
|
default=100000,
|
||||||
|
metadata={
|
||||||
|
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_seq_length: int = field(
|
||||||
|
default=2048,
|
||||||
|
metadata={"help": "The maximum sequence length the model can handle"},
|
||||||
|
)
|
||||||
|
relora_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how often to reset for ReLoRA"},
|
||||||
|
)
|
||||||
|
relora_warmup_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
|
)
|
||||||
|
relora_anneal_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
|
)
|
||||||
|
relora_prune_ratio: Optional[float] = field(
|
||||||
|
default=0.9,
|
||||||
|
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||||
|
)
|
||||||
|
bench_split: Optional[str] = field(
|
||||||
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
|
)
|
||||||
|
bench_dataset: Optional[str] = field(
|
||||||
|
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||||
|
metadata={
|
||||||
|
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_bench_eval: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||||
|
)
|
||||||
|
do_causal_lm_eval: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||||
|
)
|
||||||
|
max_bench_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
bench_source_max_len: int = field(
|
||||||
|
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||||
|
)
|
||||||
|
dataloader_prefetch_factor: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||||
|
)
|
||||||
|
cosine_min_lr_ratio: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||||
|
)
|
||||||
|
cosine_constant_lr_ratio: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
loraplus_lr_ratio: Optional[float] = field(
|
||||||
|
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||||
|
)
|
||||||
|
loraplus_lr_embedding: Optional[float] = field(
|
||||||
|
default=1e-6,
|
||||||
|
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||||
|
)
|
||||||
|
embedding_lr_scale: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||||
|
)
|
||||||
|
lr_groups: Optional[list[dict]] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Specify learning rate groups for with different LRs."},
|
||||||
|
)
|
||||||
|
embedding_lr: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||||
|
)
|
||||||
|
qlora: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "whether this is a qlora training"},
|
||||||
|
)
|
||||||
|
orpo_alpha: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
lisa_n_layers: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "the number of activate layers in LISA"},
|
||||||
|
)
|
||||||
|
lisa_step_interval: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how often to switch layers in LISA"},
|
||||||
|
)
|
||||||
|
lisa_layers_attribute: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "path under the model to access the layers"},
|
||||||
|
)
|
||||||
|
curriculum_sampling: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||||
|
)
|
||||||
|
alternate_lr_scheduler_type: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
chat_template: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Chat template converting chat messages to text"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# kd_ce_alpha: Optional[float] = field(
|
||||||
|
# default=None,
|
||||||
|
# metadata={
|
||||||
|
# "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# kd_alpha: Optional[float] = field(
|
||||||
|
# default=1.0,
|
||||||
|
# metadata={"help": "The alpha scaling parameter for KD loss"},
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# kd_temperature: Optional[float] = field(
|
||||||
|
# default=1.0,
|
||||||
|
# metadata={
|
||||||
|
# "help": "the temperature parameter for KL divergence loss when using KD"
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
|
||||||
|
adam_beta3: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
adam_epsilon2: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# multi-modal section
|
||||||
|
|
||||||
|
image_size: int | tuple[int, int] | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The size of the image to resize to"},
|
||||||
|
)
|
||||||
|
|
||||||
|
image_resize_algorithm: Resampling | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The algorithm to use for image resizing"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# end of multi-modal section
|
||||||
@@ -64,10 +64,6 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
desc="Strategy Filtering Rows",
|
desc="Strategy Filtering Rows",
|
||||||
)
|
)
|
||||||
|
|
||||||
import ipdb
|
|
||||||
|
|
||||||
ipdb.set_trace()
|
|
||||||
|
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
num_proc=num_proc,
|
num_proc=num_proc,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
import importlib
|
import importlib
|
||||||
|
import traceback
|
||||||
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
||||||
|
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
@@ -83,6 +84,11 @@ class BasePlugin:
|
|||||||
def get_input_args(self) -> str | None:
|
def get_input_args(self) -> str | None:
|
||||||
"""Returns a pydantic model for the plugin's input arguments."""
|
"""Returns a pydantic model for the plugin's input arguments."""
|
||||||
|
|
||||||
|
def get_training_args_mixin(self) -> str | None:
    """Returns the dotted import path of the plugin's training-args mixin.

    The string names a dataclass (``"package.module.ClassName"``) whose fields
    should be mixed into the training arguments. Plugins that add no training
    arguments keep this default implementation, which returns ``None``.
    """
|
||||||
|
|
||||||
def load_datasets(
|
def load_datasets(
|
||||||
self, cfg: DictDefault, preprocess: bool = False
|
self, cfg: DictDefault, preprocess: bool = False
|
||||||
) -> Union["TrainDatasetMeta", None]:
|
) -> Union["TrainDatasetMeta", None]:
|
||||||
@@ -158,6 +164,31 @@ class BasePlugin:
|
|||||||
trainer: The trainer object for training.
|
trainer: The trainer object for training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def get_training_args(self, cfg: DictDefault):  # pylint: disable=unused-argument
    """
    Returns custom training arguments to set on TrainingArgs.

    Args:
        cfg: The global axolotl configuration.

    Returns:
        dict | None: Keyword arguments to set on the training arguments, or
            None (this default implementation) when the plugin adds none.
    """
|
||||||
|
|
||||||
|
def get_collator_cls_and_kwargs(
    self, cfg: DictDefault, is_eval: bool = False
):  # pylint: disable=unused-argument
    """
    Returns a custom class for the collator, together with its kwargs.

    Args:
        cfg: The global axolotl configuration.
        is_eval: Whether this is an eval split.

    Returns:
        tuple | None: A ``(collator_class, collator_kwargs)`` pair, or None
            (this default implementation) when the plugin provides no collator.
    """
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
|
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
|
||||||
"""Creates and returns an optimizer for training.
|
"""Creates and returns an optimizer for training.
|
||||||
@@ -278,7 +309,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
|
|||||||
return plugin
|
return plugin
|
||||||
|
|
||||||
|
|
||||||
class PluginManager:
|
class PluginManager: # pylint: disable=too-many-public-methods
|
||||||
"""The `PluginManager` class is responsible for loading and managing plugins. It
|
"""The `PluginManager` class is responsible for loading and managing plugins. It
|
||||||
should be a singleton so it can be accessed from anywhere in the codebase.
|
should be a singleton so it can be accessed from anywhere in the codebase.
|
||||||
|
|
||||||
@@ -337,8 +368,11 @@ class PluginManager:
|
|||||||
plugin = load_plugin(plugin_name)
|
plugin = load_plugin(plugin_name)
|
||||||
self.plugins[plugin_name] = plugin
|
self.plugins[plugin_name] = plugin
|
||||||
LOG.info(f"Plugin loaded successfully: {plugin_name}")
|
LOG.info(f"Plugin loaded successfully: {plugin_name}")
|
||||||
except ImportError:
|
except ImportError as exc:
|
||||||
LOG.error(f"Failed to load plugin: {plugin_name}")
|
LOG.error(f"Failed to load plugin: {plugin_name}")
|
||||||
|
# print stacktrace
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"Error: {exc}")
|
||||||
|
|
||||||
def get_input_args(self) -> list[str]:
|
def get_input_args(self) -> list[str]:
|
||||||
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
|
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
|
||||||
@@ -353,6 +387,20 @@ class PluginManager:
|
|||||||
input_args.append(input_args_from_plugin)
|
input_args.append(input_args_from_plugin)
|
||||||
return input_args
|
return input_args
|
||||||
|
|
||||||
|
def get_training_args_mixin(self) -> list[str]:
    """
    Returns the training-args mixins of all registered plugins.

    Returns:
        list[str]: One entry per plugin whose `get_training_args_mixin`
            returned a value (a dotted import path to a dataclass), in plugin
            iteration order. Plugins that return None are skipped.
    """
    candidates = (
        plugin.get_training_args_mixin() for plugin in self.plugins.values()
    )
    return [mixin for mixin in candidates if mixin is not None]
|
||||||
|
|
||||||
def load_datasets(
|
def load_datasets(
|
||||||
self, cfg: DictDefault, preprocess: bool = False
|
self, cfg: DictDefault, preprocess: bool = False
|
||||||
) -> Union["TrainDatasetMeta", None]:
|
) -> Union["TrainDatasetMeta", None]:
|
||||||
@@ -442,6 +490,42 @@ class PluginManager:
|
|||||||
return trainer_cls
|
return trainer_cls
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_training_args(self, cfg):
    """
    Calls the get_training_args method of all registered plugins and returns the combined training arguments.

    Plugins are visited in iteration order of the registry; when two plugins
    supply the same key, the value from the later plugin wins.

    Parameters:
        cfg (dict): The configuration for the plugins.

    Returns:
        dict: The merged training arguments (empty when no plugin adds any).
    """
    merged_kwargs = {}
    for plugin in self.plugins.values():
        plugin_kwargs = plugin.get_training_args(cfg)
        if plugin_kwargs is None:
            continue
        merged_kwargs.update(plugin_kwargs)

    return merged_kwargs
|
||||||
|
|
||||||
|
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
    """
    Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first usable collator.

    Parameters:
        cfg (dict): The configuration for the plugins.
        is_eval (bool): Whether this is an eval split.

    Returns:
        tuple | None: The ``(collator_class, collator_kwargs)`` pair of the
            first plugin that provides a collator class, or None if none was found.
    """
    for plugin in self.plugins.values():
        collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
        if collator is None:
            continue
        collator_cls, collator_kwargs = collator
        # A plugin may answer `(None, None)` to say "nothing to offer"; that must
        # not shadow a later plugin that does provide a collator.
        if collator_cls is not None:
            return collator_cls, collator_kwargs
    return None
|
||||||
|
|
||||||
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
||||||
"""Calls the `post_trainer_create` method of all registered plugins.
|
"""Calls the `post_trainer_create` method of all registered plugins.
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio
|
|||||||
This was moved here to prevent circular imports.
|
This was moved here to prevent circular imports.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Type
|
||||||
|
|
||||||
from axolotl.utils.schemas.config import (
|
from axolotl.utils.schemas.config import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
@@ -61,3 +61,43 @@ def merge_input_args():
|
|||||||
]
|
]
|
||||||
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
||||||
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
||||||
|
|
||||||
|
|
||||||
|
def merge_training_args() -> Type:
    """
    Merges training arguments from registered plugins with the base TrainingArguments mixins.

    The mixin import paths are collected from the PluginManager. Each path is a
    dotted ``"package.module.ClassName"`` string; the named classes are imported
    and combined with the base mixins into one new class.

    Returns:
        Type: The base ``AxolotlTrainingMixins`` when no plugin contributes a
            mixin; otherwise a new class named ``AxolotlTrainingMixins`` that
            inherits from the base followed by every plugin mixin, in the order
            the plugin manager reports them.

    Raises:
        ImportError: If a mixin's module or class cannot be imported.
    """
    # pylint: disable=duplicate-code
    import importlib

    from axolotl.core.training_args_base import (
        AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
    )
    from axolotl.integrations.base import PluginManager

    plugin_manager = PluginManager.get_instance()
    training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
    if not training_args_mixins:
        return AxolotlTrainingMixinsBase

    mixin_classes = []
    for mixin_path in training_args_mixins:
        module_name, class_name = mixin_path.rsplit(".", 1)
        module = importlib.import_module(module_name)
        try:
            mixin_classes.append(getattr(module, class_name))
        except AttributeError as exc:
            # keep the failure mode of `from module import name`
            raise ImportError(
                f"cannot import name {class_name!r} from {module_name!r}"
            ) from exc

    # Build the combined class directly instead of exec-ing generated source.
    return type(
        "AxolotlTrainingMixins",
        (AxolotlTrainingMixinsBase, *mixin_classes),
        {},
    )
|
||||||
|
|||||||
@@ -21,3 +21,32 @@ datasets:
|
|||||||
```
|
```
|
||||||
|
|
||||||
An example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)
|
An example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)
|
||||||
|
|
||||||
|
## Online KD (sglang)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export UV_TORCH_BACKEND=cu124
|
||||||
|
uv venv sglang --python 3.11
|
||||||
|
source sglang/bin/activate
|
||||||
|
uv pip install --upgrade pip
|
||||||
|
uv pip install setuptools
|
||||||
|
uv pip install torch~=2.5.1 --index-url https://download.pytorch.org/whl/cu124
|
||||||
|
uv pip install sgl-kernel --force-reinstall --no-deps
|
||||||
|
uv pip install "sglang[all]>=0.4.2.post4" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
|
||||||
|
```
|
||||||
|
|
||||||
|
## Online KD (vllm)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
VLLM_USE_V1=0 vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --max-num-seqs 256 --gpu_memory_utilization 0.2 --enable-chunked-prefill
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --no-enable-prefix-caching --gpu-memory-utilization 0.3 --max-num-batched-tokens 131072 --host 0.0.0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m sglang.launch_server --model-path open-r1/OlympicCoder-32B --tensor-parallel-size 8 --port 8080 --host 0.0.0.0 --max-running-requests 256 --context-length 16400 --mem-fraction-static 0.2 --schedule-conservativeness 0.3 --chunked-prefill-size 131072 --schedule-policy fcfs --skip-tokenizer-init
|
||||||
|
```
|
||||||
|
|||||||
@@ -15,7 +15,12 @@
|
|||||||
"""
|
"""
|
||||||
Plugin init to add KD support to Axolotl.
|
Plugin init to add KD support to Axolotl.
|
||||||
"""
|
"""
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
|
||||||
|
|
||||||
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
@@ -28,9 +33,75 @@ class KDPlugin(BasePlugin):
|
|||||||
def get_input_args(self):
|
def get_input_args(self):
|
||||||
return "axolotl.integrations.kd.KDArgs"
|
return "axolotl.integrations.kd.KDArgs"
|
||||||
|
|
||||||
|
def get_training_args_mixin(self):
    """Returns the dotted import path of the KD training-args mixin dataclass."""
    return "axolotl.integrations.kd.args.KDTrainingArgsMixin"
|
||||||
|
|
||||||
def get_trainer_cls(self, cfg):
|
def get_trainer_cls(self, cfg):
|
||||||
if cfg.kd_trainer:
|
if cfg.kd_trainer:
|
||||||
from .trainer import AxolotlKDTrainer
|
from .trainer import AxolotlKDTrainer
|
||||||
|
|
||||||
return AxolotlKDTrainer
|
return AxolotlKDTrainer
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_training_args(self, cfg):
    """Returns the KD hyperparameters to forward to the training arguments.

    Args:
        cfg: The global axolotl configuration; read for its `kd_*` settings.

    Returns:
        dict: `kd_ce_alpha`, `kd_alpha`, `kd_temperature`, `kd_beta` and
            `kd_normalize_topk`, each copied unchanged from `cfg`.
    """
    forwarded_keys = (
        "kd_ce_alpha",
        "kd_alpha",
        "kd_temperature",
        "kd_beta",
        "kd_normalize_topk",
    )
    return {key: getattr(cfg, key) for key in forwarded_keys}
|
||||||
|
|
||||||
|
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
    """Selects the data collator used for knowledge distillation.

    Args:
        cfg: The global axolotl configuration.
        is_eval: Whether the collator is for an eval split.

    Returns:
        tuple | None: None when `cfg.kd_trainer` is not enabled. Otherwise a
            ``(collator_class, collator_kwargs)`` pair: the online-teacher
            collator when `cfg.kd_online_server_base_url` is set, else the
            batch-sampler KD collator when sample packing applies to this
            split, else the plain KD collator.
    """
    if not cfg.kd_trainer:
        # `None` is the "no collator" answer; a `(None, None)` tuple is not
        # None and would be mistaken for a real result by an `is not None` check.
        return None

    from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq

    if cfg.kd_online_server_base_url:
        from .collator_online_teacher import OnlineTeacherCollator

        return OnlineTeacherCollator, {
            "kd_online_server_base_url": cfg.kd_online_server_base_url,
            "kd_online_topk": cfg.kd_online_topk,
            "kd_temperature": cfg.kd_temperature,
            "kd_online_server": cfg.kd_online_server,
            "kd_online_timeout": cfg.kd_online_timeout,
            "kd_normalize_topk": cfg.kd_normalize_topk,
        }

    # packing is configured separately for the train and eval splits
    packs_train_split = is_eval is False and cfg.sample_packing
    packs_eval_split = cfg.eval_sample_packing and is_eval
    if packs_train_split or packs_eval_split:
        return KDBatchSamplerDataCollatorForSeq2Seq, {}
    return DataCollatorForKD, {}
|
||||||
|
|
||||||
|
def pre_model_load(self, cfg):
    """Applies the KD kernel for the configured model type.

    Args:
        cfg: The global axolotl configuration; `cfg.model_config_type` is
            passed to `apply_kernel`.
    """
    # the kernels module is only imported when this hook actually runs
    from .kernels.models import apply_kernel

    apply_kernel(cfg.model_config_type)
|
||||||
|
|
||||||
|
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
    """
    Adds the KD temperature scheduler callback to the Trainer instance.

    The callback is only created when `cfg.kd_temperature_min` is set and an
    online teacher server (`cfg.kd_online_server_base_url`) is configured.

    Args:
        cfg (Any): Configuration object providing `kd_temperature`,
            `kd_temperature_min` and `kd_online_server_base_url`.
        trainer (Trainer): Huggingface Trainer instance.

    Returns:
        list: List containing the configured callback instances (empty when
            temperature scheduling is not enabled).
    """
    if cfg.kd_temperature_min is None or not cfg.kd_online_server_base_url:
        return []

    scheduler = KDTemperatureSchedulerCallback(
        cfg.kd_temperature,
        cfg.kd_temperature_min,
        trainer,
    )
    return [scheduler]
|
||||||
|
|||||||
@@ -15,9 +15,19 @@
|
|||||||
"""
|
"""
|
||||||
Plugin args for KD support.
|
Plugin args for KD support.
|
||||||
"""
|
"""
|
||||||
from typing import Optional
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceServerType(str, Enum):
    """
    Online inference server types; each backend takes different request args
    """

    vllm = "vllm"  # pylint: disable=invalid-name
    sglang = "sglang"  # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class KDArgs(BaseModel):
|
class KDArgs(BaseModel):
|
||||||
@@ -25,13 +35,41 @@ class KDArgs(BaseModel):
|
|||||||
Input args for knowledge distillation.
|
Input args for knowledge distillation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kd_trainer: Optional[bool] = None # whether to use KD trainer
|
kd_trainer: float | None = None # whether to use KD trainer
|
||||||
kd_ce_alpha: Optional[float] = (
|
kd_ce_alpha: float | None = (
|
||||||
None # loss coefficient for cross-entropy loss during KD
|
None # loss coefficient for cross-entropy loss during KD
|
||||||
)
|
)
|
||||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
kd_alpha: float | None = None # loss coefficient for KD loss
|
||||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
kd_temperature: float | None = None # temperature for sampling during KD
|
||||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
|
||||||
kd_top_k_before_softmax: Optional[bool] = (
|
kd_normalize_topk: bool | None = (
|
||||||
None # whether to sample top k before softmax during KD
|
None # whether to normalize student logits during KD
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO online kd
|
||||||
|
kd_online_server_base_url: str | None = None
|
||||||
|
kd_online_topk: int | None = None
|
||||||
|
kd_online_server: InferenceServerType | None = Field(
|
||||||
|
default_factory=lambda: InferenceServerType.vllm
|
||||||
|
)
|
||||||
|
kd_online_timeout: int | None = 120
|
||||||
|
kd_temperature_min: float | None = (
|
||||||
|
None # kd temperature scheduling during online kd
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KDTrainingArgsMixin:
|
||||||
|
"""
|
||||||
|
Additional args for KD training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
kd_ce_alpha: float | None = (
|
||||||
|
None # loss coefficient for cross-entropy loss during KD
|
||||||
|
)
|
||||||
|
kd_alpha: float | None = None # loss coefficient for KD loss
|
||||||
|
kd_temperature: float | None = None # temperature for sampling during KD
|
||||||
|
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
|
||||||
|
kd_normalize_topk: float | None = (
|
||||||
|
None # whether to normalize student logits during KD
|
||||||
)
|
)
|
||||||
|
|||||||
36
src/axolotl/integrations/kd/callbacks.py
Normal file
36
src/axolotl/integrations/kd/callbacks.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""
|
||||||
|
Transformers trainer callbacks to schedule the KD temperature during training
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
from transformers.trainer_callback import TrainerCallback
|
||||||
|
|
||||||
|
|
||||||
|
class KDTemperatureSchedulerCallback(TrainerCallback):
    """
    KD temperature scheduler callback for the trainer.

    After every step the temperature is moved along a cosine curve from
    `temperature_start` (at step 0) to `temperature_min` (at the final step)
    and written to the trainer's data collator when it has a
    `kd_temperature` attribute.
    """

    def __init__(self, temperature_start, temperature_min, trainer):
        """
        Args:
            temperature_start: Temperature at the beginning of training.
            temperature_min: Temperature reached at the end of training.
            trainer: Trainer whose `data_collator` receives the updates.
        """
        self.temperature_start = temperature_start
        self.temperature_min = temperature_min
        self.temperature = temperature_start

        self.trainer = trainer

    def on_step_end(
        self, args, state, control, **kwargs
    ):  # pylint: disable=unused-argument
        """Recomputes the temperature from `state.global_step / state.max_steps`."""
        max_steps = state.max_steps
        if not max_steps or max_steps < 0:
            # no known horizon to decay over; keep the current temperature
            # rather than dividing by zero
            return

        # cosine decay temperature over the max steps; clamp so that a step
        # count past max_steps cannot swing the temperature back up
        progress = min(state.global_step / max_steps, 1.0)
        # Cosine decay factor: 0.5 * (1 + cos(pi * progress))
        # This factor goes from 1 (at progress=0) to 0 (at progress=1)
        decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
        self.temperature = self.temperature_start - (
            (self.temperature_start - self.temperature_min) * (1.0 - decay_factor)
        )

        if hasattr(self.trainer.data_collator, "kd_temperature"):
            self.trainer.data_collator.kd_temperature = self.temperature
|
||||||
@@ -15,12 +15,15 @@
|
|||||||
"""
|
"""
|
||||||
Chat template prompt strategy loader with KD support
|
Chat template prompt strategy loader with KD support
|
||||||
"""
|
"""
|
||||||
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
|
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||||
"""
|
"""
|
||||||
@@ -101,10 +104,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
# fill with -inf for padding_len tokens for top_k tokens
|
# fill with -inf for padding_len tokens for top_k tokens
|
||||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||||
|
|
||||||
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
# we shift for causal models in the trainer, so start the range from 0
|
||||||
# otherwise, we need to shift in the trainer
|
for _ in range(0, input_padding_len):
|
||||||
shift = 0
|
|
||||||
for _ in range(shift, input_padding_len):
|
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
target_token_ids.append(list(range(top_k)))
|
target_token_ids.append(list(range(top_k)))
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
@@ -143,6 +144,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
#
|
#
|
||||||
# Convert from log to probability
|
# Convert from log to probability
|
||||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||||
|
# normalize probabilities to sum to 1 in case they aren't already
|
||||||
|
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
|
||||||
|
if teacher_probs_t1_sum > 1e-9:
|
||||||
|
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
|
||||||
if self.kd_temperature != self.gen_temperature:
|
if self.kd_temperature != self.gen_temperature:
|
||||||
# Exponentiate by factor (T1 / T2)
|
# Exponentiate by factor (T1 / T2)
|
||||||
exponent = self.gen_temperature / self.kd_temperature
|
exponent = self.gen_temperature / self.kd_temperature
|
||||||
@@ -162,12 +167,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
target_logprobs.append(position_logprobs_scaled)
|
target_logprobs.append(position_logprobs_scaled)
|
||||||
target_token_ids.append(position_token_ids)
|
target_token_ids.append(position_token_ids)
|
||||||
|
|
||||||
if shift == 1:
|
|
||||||
# since we started at index 1 for causal, we need one more padding token
|
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
|
||||||
target_token_ids.append(list(range(top_k)))
|
|
||||||
target_mask.append([0] * top_k)
|
|
||||||
|
|
||||||
# Update sample with transformed logprobs
|
# Update sample with transformed logprobs
|
||||||
sample["target_logprobs"] = target_logprobs
|
sample["target_logprobs"] = target_logprobs
|
||||||
sample["target_token_ids"] = target_token_ids
|
sample["target_token_ids"] = target_token_ids
|
||||||
@@ -184,13 +183,124 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
return tokenized_prompt
|
return tokenized_prompt
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
    """
    Strat for datasets with complete structured KD logprob data
    """

    def transform_logprobs(self, sample):
        """
        Transform logprobs to target format for KD training

        Pops `self.logprobs_field` from `sample`, reads `sample["input_ids"]`,
        `sample["labels"]` and `sample["target_token_ids"]`, and writes
        `target_logprobs`, `target_token_ids` and `target_mask` back to it.

        Args:
            sample: Tokenized sample dict (modified in place).

        Returns:
            The same `sample` dict with the three target fields set.

        Raises:
            ValueError: If no position carries a non-empty logprob row.
        """
        # pylint: disable=duplicate-code

        logprobs = sample.pop(self.logprobs_field)
        target_seq_len = len(logprobs)
        input_seq_len = len(sample["input_ids"])
        input_padding_len = input_seq_len - target_seq_len
        # get non-zero top-k (prune None logprobs from vllm data step)
        top_k_vals = [len(row) for row in logprobs if row is not None and len(row)]
        if not top_k_vals:
            # checked up front: every counted width is >= 1, so testing the
            # derived top_k against 0 could never fire, and max() on an empty
            # set would fail with an unrelated message
            raise ValueError("No non-zero top-k logprobs found.")
        # most frequent and least frequent row width; use the smaller of the two
        # NOTE(review): this is not necessarily the minimum width — confirm intent
        max_top_k = max(set(top_k_vals), key=top_k_vals.count)
        min_top_k = min(set(top_k_vals), key=top_k_vals.count)
        top_k = min(max_top_k, min_top_k)

        target_logprobs = []
        target_token_ids = []
        target_mask = []

        if input_padding_len < 0:
            # logprobs is longer than target_seq_len,
            # so we need to slice from the left/beginning of logprobs
            # NOTE(review): `[:-input_seq_len]` drops the last input_seq_len rows
            # rather than keeping input_seq_len of them — confirm which end is meant
            logprobs = logprobs[:-input_seq_len]
            input_padding_len = 0
            # target_seq_len = input_seq_len

        # truncate the second dimension of the logprobs to top_k
        logprobs = [row[:top_k] for row in logprobs]

        # fill with -inf for padding_len tokens for top_k tokens
        # extend target_logprobs with a padding_len x top_k 2D list filled with -inf

        # we shift for causal models in the trainer, so start the range from 0
        for _ in range(0, input_padding_len):
            target_logprobs.append([-float("inf")] * top_k)
            target_token_ids.append(list(range(top_k)))
            target_mask.append([0] * top_k)

        for position in range(input_padding_len, input_seq_len):
            # positions whose label is -100 are masked out
            if sample["labels"][position] == -100:
                target_mask.append([0] * top_k)
            else:
                target_mask.append([1] * top_k)

        for token_pos_logprobs, pos_target_token_ids in zip(
            logprobs, sample["target_token_ids"]
        ):
            # Convert to a tensor for easier manipulation
            position_logprobs_tensor = torch.tensor(
                token_pos_logprobs, dtype=torch.float
            )

            # Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
            # Next, re-scale to T2 = self.kd_temperature via exponent-based trick
            # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
            #
            # Convert from log to probability
            teacher_probs_t1 = position_logprobs_tensor.exp()
            # normalize probabilities to sum to 1 in case they aren't already
            teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
            if teacher_probs_t1_sum > 1e-9:
                teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
            if self.kd_temperature != self.gen_temperature:
                # Exponentiate by factor (T1 / T2)
                exponent = self.gen_temperature / self.kd_temperature
                teacher_probs_t2 = teacher_probs_t1**exponent
            else:
                teacher_probs_t2 = teacher_probs_t1
            # Re-normalize
            teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
                dim=0, keepdim=True
            )
            # Convert back to log
            position_logprobs_tensor = torch.log(teacher_probs_t2)

            # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
            position_logprobs_scaled = position_logprobs_tensor.tolist()

            target_logprobs.append(position_logprobs_scaled)
            # keep the ids the same width as the truncated logprob row
            target_token_ids.append(pos_target_token_ids[:top_k])

        # Update sample with transformed logprobs
        sample["target_logprobs"] = target_logprobs
        sample["target_token_ids"] = target_token_ids
        sample["target_mask"] = target_mask

        return sample

    def _tokenize_single_prompt(self, prompt):
        """Tokenizes `prompt`, then converts its logprob fields to KD targets.

        The logprobs and `target_token_ids` are taken out of `prompt` before the
        parent tokenization and re-attached to its result, which is then passed
        through `transform_logprobs`.
        """
        logprobs = prompt.pop(self.logprobs_field)
        target_token_ids = prompt.pop("target_token_ids")
        tokenized_prompt = super()._tokenize_single_prompt(prompt)
        tokenized_prompt[self.logprobs_field] = logprobs
        tokenized_prompt["target_token_ids"] = target_token_ids
        tokenized_prompt = self.transform_logprobs(tokenized_prompt)

        return tokenized_prompt
|
||||||
|
|
||||||
|
|
||||||
class KDStrategyLoader(StrategyLoader):
|
class KDStrategyLoader(StrategyLoader):
|
||||||
"""
|
"""
|
||||||
Load ChatTemplateStrategy with KD support using StrategyLoader.
|
Load ChatTemplateStrategy with KD support using StrategyLoader.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_strategy_cls(self):
|
def _get_strategy_cls(self):
|
||||||
return ChatTemplateStrategyWithKD
|
return ChatTemplateStrategyWithKDv2
|
||||||
|
|
||||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||||
strategy_params = super()._get_strategy_params(cfg, ds_cfg)
|
strategy_params = super()._get_strategy_params(cfg, ds_cfg)
|
||||||
|
|||||||
@@ -47,11 +47,16 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
|||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
    """Builds the collator, then pre-sets a tokenizer deprecation-warning flag."""
    super().__init__(*args, **kwargs)
    # Marks this warning key as already handled on the tokenizer; presumably
    # this silences the fast-tokenizer padding notice — confirm against transformers.
    self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
if return_tensors is None:
|
if return_tensors is None:
|
||||||
return_tensors = self.return_tensors
|
return_tensors = self.return_tensors
|
||||||
|
|
||||||
padding_side = self.tokenizer.padding_side
|
padding_side = self.tokenizer.padding_side
|
||||||
|
max_len = 0
|
||||||
|
|
||||||
# Pad labels and position_ids first
|
# Pad labels and position_ids first
|
||||||
for feature_name, pad_token_id in [
|
for feature_name, pad_token_id in [
|
||||||
@@ -102,7 +107,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
|||||||
target_mask_list.append(f.pop("target_mask"))
|
target_mask_list.append(f.pop("target_mask"))
|
||||||
|
|
||||||
# Determine max lengths
|
# Determine max lengths
|
||||||
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
|
max_teacher_seq_len = max_len or max(
|
||||||
|
len(seq) for seq in target_logprobs_list
|
||||||
|
)
|
||||||
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
|
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
|
||||||
|
|
||||||
padded_target_logprobs = []
|
padded_target_logprobs = []
|
||||||
@@ -209,7 +216,9 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
|
|||||||
# We want to produce a single "merged" feature dict for each sub-batch.
|
# We want to produce a single "merged" feature dict for each sub-batch.
|
||||||
out_features = [{} for _ in features]
|
out_features = [{} for _ in features]
|
||||||
|
|
||||||
for i, sub_features in enumerate(features):
|
for i, sub_features in enumerate( # pylint: disable=too-many-nested-blocks
|
||||||
|
features
|
||||||
|
):
|
||||||
# sub_features is a list of dicts, each dict = one sequence’s features
|
# sub_features is a list of dicts, each dict = one sequence’s features
|
||||||
# We'll merge them into out_features[i].
|
# We'll merge them into out_features[i].
|
||||||
#
|
#
|
||||||
@@ -243,10 +252,17 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
|
|||||||
# For example, input_ids or labels are often arrays.
|
# For example, input_ids or labels are often arrays.
|
||||||
arrays = []
|
arrays = []
|
||||||
for feat in sub_features:
|
for feat in sub_features:
|
||||||
if field_name in feat:
|
if field_name in feat and isinstance(
|
||||||
|
feat[field_name], (list, torch.Tensor)
|
||||||
|
):
|
||||||
|
if isinstance(
|
||||||
|
feat[field_name][0], (dict, str)
|
||||||
|
): # pylint: disable=too-many-nested-blocks
|
||||||
|
continue
|
||||||
arr = np.array(feat[field_name])
|
arr = np.array(feat[field_name])
|
||||||
arrays.append(arr)
|
arrays.append(arr)
|
||||||
out_features[i][field_name] = np.concatenate(arrays)
|
if arrays:
|
||||||
|
out_features[i][field_name] = np.concatenate(arrays)
|
||||||
|
|
||||||
# 3) Now call the parent collator, which will do:
|
# 3) Now call the parent collator, which will do:
|
||||||
# - padding of labels/position_ids
|
# - padding of labels/position_ids
|
||||||
|
|||||||
561
src/axolotl/integrations/kd/collator_online_teacher.py
Normal file
561
src/axolotl/integrations/kd/collator_online_teacher.py
Normal file
@@ -0,0 +1,561 @@
|
|||||||
|
"""
|
||||||
|
Packed data loader for online teacher training supporting vllm and sglang.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
from orjson import orjson
|
||||||
|
|
||||||
|
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
from axolotl.integrations.kd.utils import normalize_logprobs
|
||||||
|
from axolotl.utils.data.utils import retry_on_request_exceptions
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256):
    """
    Compute a keyed HMAC digest over a sequence of integers.

    Every integer is serialized as a 4-byte big-endian unsigned value and the
    serialized values are concatenated in order to form the message.

    Args:
        int_list: Iterable of integers; each must fit in 4 unsigned bytes
            (``int.to_bytes`` raises ``OverflowError`` otherwise)
        key: Secret key; a ``str`` is UTF-8 encoded, any other value is
            handed to ``hmac.new`` unchanged
        hash_func: Digest constructor passed to ``hmac.new`` (default: sha256)

    Returns:
        HMAC digest as hex string
    """
    key_bytes = key.encode("utf-8") if isinstance(key, str) else key
    message = b"".join(value.to_bytes(4, byteorder="big") for value in int_list)
    return hmac.new(key_bytes, message, hash_func).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||||
|
"""
|
||||||
|
Collator for online teacher training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
|
||||||
|
|
||||||
|
def __init__(
    self,
    *args: Any,
    kd_online_server_base_url: Optional[str] = None,
    kd_online_topk: Optional[int] = None,
    kd_temperature: Optional[float] = 1.0,
    kd_online_server: Optional[str] = "vllm",
    kd_online_timeout: Optional[int] = 120,
    kd_cache_dir: Optional[str] = None,
    kd_normalize_topk: Optional[bool] = True,
    **kwargs: Any,
):
    """
    Initialize the parent collator and store the online-teacher settings.

    Args:
        *args: Positional arguments forwarded to the parent collator.
        kd_online_server_base_url: Base URL of the teacher server. Required;
            trailing slashes are stripped before it is stored.
        kd_online_topk: Number of top-k teacher logprobs per position.
            Required; must be greater than zero.
        kd_temperature: Temperature placed in the request payloads.
        kd_online_server: Teacher backend name, stored as given.
        kd_online_timeout: Timeout handed to each HTTP POST.
        kd_cache_dir: Stored on the instance only.
            NOTE(review): presumably reserved for caching teacher targets.
        kd_normalize_topk: Whether fetched top-k logprobs are re-normalized.
        **kwargs: Keyword arguments forwarded to the parent collator.

    Raises:
        ValueError: If ``kd_online_server_base_url`` is ``None`` or
            ``kd_online_topk`` is ``None`` or not positive.
    """
    super().__init__(*args, **kwargs)

    if kd_online_server_base_url is None:
        raise ValueError(
            "kd_online_server_base_url must be provided for OnlineTeacherCollator"
        )
    if kd_online_topk is None or kd_online_topk <= 0:
        raise ValueError(
            "kd_online_topk must be a positive integer for OnlineTeacherCollator"
        )

    self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/")
    self.kd_online_topk = kd_online_topk
    self.kd_temperature = kd_temperature
    self.kd_online_server = kd_online_server
    # One shared session so every teacher request goes through the same
    # requests.Session object.
    self.http_session = requests.Session()
    self.kd_online_timeout = kd_online_timeout
    self.kd_cache_dir = kd_cache_dir
    self.kd_normalize_topk = kd_normalize_topk
|
||||||
|
|
||||||
|
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
|
||||||
|
"""
|
||||||
|
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
|
||||||
|
"""
|
||||||
|
if not raw_logprobs or self.kd_online_topk == 0:
|
||||||
|
return (
|
||||||
|
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
|
||||||
|
return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()
|
||||||
|
|
||||||
|
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_sglang(
    self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
    """
    Fetch top-k teacher logprobs for a batch of sequences from an sglang server.

    The batch is POSTed to ``{base_url}/generate`` with ``max_new_tokens`` set
    to 0. For every sequence the response's
    ``meta_info["input_top_logprobs"]`` is read; each position is expected to
    be ``None`` or a list of ``[logprob, token_id, token_text]`` triples.

    Args:
        batch_input_ids: One list of token ids per sequence.
        labels: Student labels aligned with ``batch_input_ids``. A position
            whose label equals ``DEFAULT_LABEL_PAD_TOKEN_ID`` gets an
            all-zero mask row.

    Returns:
        Dict with keys ``target_token_ids``, ``target_logprobs`` and
        ``target_mask``. Each value has one entry per sequence, and each
        entry has one row of ``kd_online_topk`` values per position. All
        three lists are empty when the response is not a list with exactly
        one item per input sequence.

    Raises:
        requests.exceptions.RequestException: HTTP failures are logged and
            re-raised. NOTE(review): presumably retried by the
            ``retry_on_request_exceptions`` decorator - confirm there.
        Exception: Any other processing error is logged and re-raised.
    """
    api_endpoint = f"{self.kd_online_server_base_url}/generate"
    topk = self.kd_online_topk

    payload = {
        "input_ids": batch_input_ids,
        "return_logprob": True,
        "top_logprobs_num": topk,
        "logprob_start_len": 0,
        "return_text_in_logprobs": True,
        "echo": True,
        "sampling_params": {
            "max_new_tokens": 0,
            "temperature": self.kd_temperature,
            "skip_special_tokens": False,
        },
    }

    # Returned unfilled when the response fails the batch-shape check below.
    ret_data_target_token_ids: List[List[List[int]]] = []
    ret_data_target_logprobs: List[List[List[float]]] = []
    ret_data_target_mask: List[List[List[int]]] = []

    def append_padding(logprob_rows, token_id_rows, mask_rows):
        # One fully masked placeholder position: -inf logprobs, token id 0.
        logprob_rows.append([-float("inf")] * topk)
        token_id_rows.append([0] * topk)
        mask_rows.append([0] * topk)

    try:
        response = self.http_session.post(
            api_endpoint, json=payload, timeout=self.kd_online_timeout
        )
        response.raise_for_status()
        api_data: list[dict] = response.json()

        # Ensure api_data is a list, and its length matches batch_input_ids
        if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids):
            LOG.error(
                f"API response format error. Expected a list of {len(batch_input_ids)} "
                f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
            )
            # Return empty data; items processed later will get default empty KD fields
            return {
                "target_token_ids": ret_data_target_token_ids,
                "target_logprobs": ret_data_target_logprobs,
                "target_mask": ret_data_target_mask,
            }

        for sequence_data, seq_input_ids, seq_labels in zip(
            api_data, batch_input_ids, labels
        ):
            current_target_logprobs: List[List[float]] = []
            current_target_token_ids: List[List[int]] = []
            current_target_mask: List[List[int]] = []

            meta_info = sequence_data.pop("meta_info", {})
            input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop(
                "input_top_logprobs", []
            )
            if not isinstance(input_top_logprobs, list):
                LOG.warning(
                    f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
                )
                input_top_logprobs = []  # Treat as empty

            # basic check that the logprob data len matches the input len, so no need to handle padding
            assert len(seq_input_ids) == len(input_top_logprobs)

            for i, (_, label) in enumerate(zip(seq_input_ids, seq_labels)):
                # A position past the end of the response is treated like a
                # None entry. None is always the case for the first token:
                # it is a true input, so there is no logprob data for it.
                pos_top_logprobs_data = (
                    input_top_logprobs[i] if i < len(input_top_logprobs) else None
                )
                if pos_top_logprobs_data is None:
                    append_padding(
                        current_target_logprobs,
                        current_target_token_ids,
                        current_target_mask,
                    )
                    continue

                # Every row must be a [logprob, token_id, token_str] triple.
                # All rows are checked (not only the first) so the 3-way
                # unpacking below cannot fail on a ragged response.
                if not (
                    isinstance(pos_top_logprobs_data, list)
                    and len(pos_top_logprobs_data) > 0
                    and all(
                        isinstance(item, list) and len(item) == 3
                        for item in pos_top_logprobs_data
                    )
                ):
                    LOG.warning(
                        f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
                    )
                    append_padding(
                        current_target_logprobs,
                        current_target_token_ids,
                        current_target_mask,
                    )
                    continue

                # Transpose the triples into parallel columns; the token
                # text column is not needed.
                pos_logprobs_raw, pos_token_ids, _ = [
                    list(column) for column in zip(*pos_top_logprobs_data)
                ]

                # Ensure correct length (top_k)
                if len(pos_logprobs_raw) < topk:
                    pad_len = topk - len(pos_logprobs_raw)
                    pos_logprobs_raw.extend([-float("inf")] * pad_len)
                    pos_token_ids.extend([0] * pad_len)  # Pad with 0 token_id

                # truncate to top_k in case the response was longer
                current_target_token_ids.append(pos_token_ids[:topk])

                if self.kd_normalize_topk:
                    current_target_logprobs.append(
                        self._normalize_logprobs(pos_logprobs_raw[:topk])
                    )
                else:
                    current_target_logprobs.append(pos_logprobs_raw[:topk])

                # Mask depends on the corresponding label for the student
                if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
                    current_target_mask.append([0] * topk)
                else:
                    current_target_mask.append([1] * topk)

            ret_data_target_token_ids.append(current_target_token_ids)
            ret_data_target_logprobs.append(current_target_logprobs)
            ret_data_target_mask.append(current_target_mask)

    except requests.exceptions.RequestException as e:
        LOG.error(f"Error fetching logprobs from online teacher: {e}")
        raise
    except Exception as e:  # Catch other potential errors during processing
        LOG.error(
            f"Unexpected error processing API response in fetch_online_logprobs: {e}",
            exc_info=True,
        )
        raise

    return {
        "target_token_ids": ret_data_target_token_ids,
        "target_logprobs": ret_data_target_logprobs,
        "target_mask": ret_data_target_mask,
    }
|
||||||
|
|
||||||
|
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_vllm(
    self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
    """
    Fetch top-k teacher logprobs for a batch of sequences from a vllm server.

    The batch is POSTed to ``{base_url}/v1/completions`` and each returned
    choice's ``prompt_logprobs`` list is read. Every position is expected to
    be ``None`` or a dict mapping a token id (as a string) to a dict holding
    a ``"logprob"`` value.

    Positions whose entry is ``None`` are skipped rather than padded in
    place, so the rows of later positions move up by one index per skipped
    entry; placeholder rows are appended at the end until the sequence has
    one row per input token again.

    Args:
        batch_input_ids: One list of token ids per sequence.
        labels: Student labels aligned with ``batch_input_ids``. A position
            whose label equals ``DEFAULT_LABEL_PAD_TOKEN_ID`` gets an
            all-zero mask row.

    Returns:
        Dict with keys ``target_token_ids``, ``target_logprobs`` and
        ``target_mask``. Each value has one entry per sequence, and each
        entry has one row of ``kd_online_topk`` values per position. All
        three lists are empty when ``choices`` is not a list with exactly
        one item per input sequence.

    Raises:
        requests.exceptions.RequestException: HTTP failures are logged and
            re-raised. NOTE(review): presumably retried by the
            ``retry_on_request_exceptions`` decorator - confirm there.
        Exception: Any other processing error (including a response without
            a ``"choices"`` key) is logged and re-raised.
    """
    api_endpoint = f"{self.kd_online_server_base_url}/v1/completions"
    topk = self.kd_online_topk

    payload = {
        "prompt": batch_input_ids,
        "echo": True,
        "logprobs": True,
        "prompt_logprobs": topk,
        "top_logprobs": topk,
        "max_new_tokens": 0,
        "skip_special_tokens": False,
        "temperature": self.kd_temperature,
        "sampling_params": {
            "max_tokens": 0,
        },
    }

    # Returned unfilled when the response fails the batch-shape check below.
    ret_data_target_token_ids: List[List[List[int]]] = []
    ret_data_target_logprobs: List[List[List[float]]] = []
    ret_data_target_mask: List[List[List[int]]] = []

    def append_padding(logprob_rows, token_id_rows, mask_rows):
        # One fully masked placeholder position: -inf logprobs and the
        # token ids 0..topk-1.
        logprob_rows.append([-float("inf")] * topk)
        token_id_rows.append(list(range(topk)))
        mask_rows.append([0] * topk)

    try:
        headers = {"Accept-Encoding": "deflate, gzip, br, zstd"}
        response = self.http_session.post(
            api_endpoint,
            json=payload,
            headers=headers,
            timeout=self.kd_online_timeout,
        )
        response.raise_for_status()
        api_data: dict = orjson.loads(response.content)
        choices: list[dict] = api_data["choices"]

        # Ensure choices is a list, and its length matches batch_input_ids
        if not isinstance(choices, list) or len(choices) != len(batch_input_ids):
            LOG.error(
                f"API response format error. Expected a list of {len(batch_input_ids)} "
                f"items, got {type(choices)} with length {len(choices) if isinstance(choices, list) else 'N/A'}."
            )
            # Return empty data; items processed later will get default empty KD fields
            return {
                "target_token_ids": ret_data_target_token_ids,
                "target_logprobs": ret_data_target_logprobs,
                "target_mask": ret_data_target_mask,
            }

        for sequence_data, seq_input_ids, seq_labels in zip(
            choices, batch_input_ids, labels
        ):
            current_target_logprobs: List[List[float]] = []
            current_target_token_ids: List[List[int]] = []
            current_target_mask: List[List[int]] = []

            input_top_logprobs: Optional[list[None | dict[str, dict]]] = (
                sequence_data.pop("prompt_logprobs", [])
            )

            if not isinstance(input_top_logprobs, list):
                LOG.warning(
                    f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
                )
                input_top_logprobs = []  # Treat as empty

            # basic check that the logprob data len matches the input len, so no need to handle padding
            assert len(seq_input_ids) == len(input_top_logprobs)

            seq_len = len(seq_input_ids)

            for i, (_, label) in enumerate(zip(seq_input_ids, seq_labels)):
                if i >= len(input_top_logprobs):
                    # No entry at all for this position: pad it in place.
                    append_padding(
                        current_target_logprobs,
                        current_target_token_ids,
                        current_target_mask,
                    )
                    continue

                pos_top_logprobs_data = input_top_logprobs[i]
                if pos_top_logprobs_data is None:
                    # this is always the case for the first token.
                    # there is never logprob data for the first token since that's a true input
                    # The skipped row is made up for by the trailing padding below.
                    continue

                # Expect a non-empty {token_id_str: {"logprob": ...}} mapping.
                if not (
                    isinstance(pos_top_logprobs_data, dict)
                    and len(pos_top_logprobs_data) > 0
                    and all(
                        isinstance(item, dict)
                        for item in pos_top_logprobs_data.values()
                    )
                ):
                    LOG.warning(
                        f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
                    )
                    append_padding(
                        current_target_logprobs,
                        current_target_token_ids,
                        current_target_mask,
                    )
                    continue

                # Keys are token ids as strings; values carry the logprob.
                pos_token_ids = [int(token_id) for token_id in pos_top_logprobs_data]
                pos_logprobs_raw = [
                    float(entry.get("logprob", -float("inf")))
                    for entry in pos_top_logprobs_data.values()
                ]

                # Ensure correct length (top_k)
                if len(pos_logprobs_raw) < topk:
                    pad_len = topk - len(pos_logprobs_raw)
                    LOG.warning(
                        f"Padding position {i} with {pad_len} top-k tokens and logprobs."
                    )
                    pos_logprobs_raw.extend([-float("inf")] * pad_len)
                    pos_token_ids.extend([0] * pad_len)  # Pad with 0 token_id

                # truncate to top_k in case the response was longer
                current_target_token_ids.append(pos_token_ids[:topk])

                if self.kd_normalize_topk:
                    current_target_logprobs.append(
                        self._normalize_logprobs(pos_logprobs_raw[:topk])
                    )
                else:
                    current_target_logprobs.append(pos_logprobs_raw[:topk])

                # Mask depends on the corresponding label for the student
                if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
                    current_target_mask.append([0] * topk)
                else:
                    current_target_mask.append([1] * topk)

            # Restore one row per input token after the skipped None entries.
            for _ in range(max(0, seq_len - len(current_target_logprobs))):
                append_padding(
                    current_target_logprobs,
                    current_target_token_ids,
                    current_target_mask,
                )

            ret_data_target_token_ids.append(current_target_token_ids)
            ret_data_target_logprobs.append(current_target_logprobs)
            ret_data_target_mask.append(current_target_mask)

            # TODO save and load targets to disk (under self.kd_cache_dir) for
            # caching for next epoch, e.g. keyed by hmac_sha_from_int_list
            # over seq_input_ids.

    except requests.exceptions.RequestException as e:
        LOG.error(f"Error fetching logprobs from online teacher: {e}")
        raise
    except Exception as e:  # Catch other potential errors during processing
        LOG.error(
            f"Unexpected error processing API response in fetch_online_logprobs: {e}",
            exc_info=True,
        )
        raise

    return {
        "target_token_ids": ret_data_target_token_ids,
        "target_logprobs": ret_data_target_logprobs,
        "target_mask": ret_data_target_mask,
    }
|
||||||
|
|
||||||
|
def __call__(
    self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
) -> Dict[str, Any]:
    """
    Attach teacher logprobs to every item, then delegate to the parent collator.

    For each sub-batch, the items that carry both ``input_ids`` and ``labels``
    are sent to the teacher server in a single request (sglang when
    ``kd_online_server == "sglang"``, vllm otherwise). The returned
    ``target_token_ids``, ``target_logprobs`` and ``target_mask`` are written
    into the item dicts in place. Items that cannot be sent, or for which no
    response row exists, get empty lists for any of those keys not already
    present.

    Args:
        features: List of sub-batches, each a list of per-sequence feature
            dicts. The dicts are mutated in place.
        return_tensors: Forwarded unchanged to the parent collator.

    Returns:
        Whatever the parent collator's ``__call__`` returns for the updated
        ``features``.

    Raises:
        AssertionError: If a fetched entry's row count differs from the
            length of the item's ``input_ids`` / ``labels``.
    """
    if not features:
        return super().__call__(features, return_tensors=return_tensors)

    for sub_batch_features in features:  # sub_batch_features is List[Dict[str, Any]]
        if not sub_batch_features:
            continue

        input_ids_for_api_call: List[List[int]] = []
        labels_for_api_call: List[List[int]] = []
        # Store references to the original item dictionaries to update them in-place
        items_for_api_call: List[Dict[str, Any]] = []

        for item_dict in sub_batch_features:
            if not isinstance(item_dict, dict):
                LOG.warning(
                    f"Skipping non-dict item in sub_batch_features: {item_dict}"
                )
                continue

            current_input_ids = item_dict.get("input_ids")
            current_labels = item_dict.get("labels")

            if current_input_ids is None or current_labels is None:
                # This item will not get teacher logprobs from the API.
                # Initialize KD fields to empty lists so downstream collators handle them uniformly.
                item_dict.setdefault("target_token_ids", [])
                item_dict.setdefault("target_logprobs", [])
                item_dict.setdefault("target_mask", [])
                continue

            # Ensure input_ids and labels are lists of ints for JSON serialization
            input_ids_for_api_call.append(
                current_input_ids.tolist()
                if hasattr(current_input_ids, "tolist")
                else list(current_input_ids)
            )
            labels_for_api_call.append(
                current_labels.tolist()
                if hasattr(current_labels, "tolist")
                else list(current_labels)
            )
            items_for_api_call.append(item_dict)

        if not items_for_api_call:  # Only call API if there's something to process
            continue

        if self.kd_online_server == "sglang":
            fetch_logprobs = self.fetch_online_logprobs_sglang
        else:
            fetch_logprobs = self.fetch_online_logprobs_vllm
        api_responses_for_sub_batch = fetch_logprobs(
            input_ids_for_api_call, labels_for_api_call
        )

        # api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
        # Each value is a list, corresponding to items_for_api_call
        for i, item_to_update in enumerate(items_for_api_call):
            # TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
            if api_responses_for_sub_batch and i < len(
                api_responses_for_sub_batch["target_token_ids"]
            ):  # Check bounds
                target_token_ids = api_responses_for_sub_batch["target_token_ids"][i]
                target_logprobs = api_responses_for_sub_batch["target_logprobs"][i]
                target_mask = api_responses_for_sub_batch["target_mask"][i]

                # One teacher row per student position is required.
                assert len(target_token_ids) == len(item_to_update["input_ids"])
                assert len(target_logprobs) == len(item_to_update["input_ids"])
                assert len(target_mask) == len(item_to_update["labels"])

                item_to_update["target_token_ids"] = target_token_ids
                item_to_update["target_logprobs"] = target_logprobs
                item_to_update["target_mask"] = target_mask
            else:
                # API call failed for this item, or response was shorter than expected.
                # Ensure KD fields are initialized as empty lists.
                LOG.warning(
                    f"No teacher logprobs for item (index {i}): API call failed or API response was too short. "
                    f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
                )
                item_to_update.setdefault("target_token_ids", [])
                item_to_update.setdefault("target_logprobs", [])
                item_to_update.setdefault("target_mask", [])

    return super().__call__(features, return_tensors=return_tensors)
|
||||||
485
src/axolotl/integrations/kd/kernels/liger.py
Normal file
485
src/axolotl/integrations/kd/kernels/liger.py
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
"""
|
||||||
|
Liger Kernels for Chunked Top-K Log-Prob Distillation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from liger_kernel.chunked_loss.fused_linear_distillation import (
|
||||||
|
LigerFusedLinearDistillationBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.integrations.kd.utils import normalize_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||||
|
"""
|
||||||
|
Chunked kl-div loss for top-k logprobs
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def distillation_loss_fn(
|
||||||
|
student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled
|
||||||
|
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||||
|
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
|
||||||
|
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||||
|
beta: float = 0.0,
|
||||||
|
normalize_topk: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute Top-K KL divergence loss for a chunk.
|
||||||
|
Args:
|
||||||
|
student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
|
||||||
|
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
|
||||||
|
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
|
||||||
|
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
|
||||||
|
beta: Controls the type of KL divergence.
|
||||||
|
0.0 for Forward KL (P_teacher || P_student).
|
||||||
|
1.0 for Reverse KL (P_student || P_teacher).
|
||||||
|
0.5 for Symmetric KL (average of Forward and Reverse).
|
||||||
|
normalize_topk: Whether to normalize the log probabilities
|
||||||
|
Returns:
|
||||||
|
Sum of KL divergence losses for the chunk.
|
||||||
|
"""
|
||||||
|
topk = target_token_ids_chunk.shape[-1]
|
||||||
|
student_logits_temp_scaled = ( # [chunk_size, vocab_size]
|
||||||
|
student_logits_temp_scaled.float()
|
||||||
|
)
|
||||||
|
target_logprobs_chunk = target_logprobs_chunk.float()
|
||||||
|
|
||||||
|
# Gather student logits for the top-k teacher token IDs
|
||||||
|
# target_token_ids_chunk: [chunk_size, top_k]
|
||||||
|
# student_logits_topk_temp_scaled: [chunk_size, top_k]
|
||||||
|
student_logits_topk_temp_scaled = torch.gather(
|
||||||
|
student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk
|
||||||
|
)
|
||||||
|
|
||||||
|
# Student log-probabilities for the gathered top-k tokens
|
||||||
|
student_lse = torch.logsumexp(
|
||||||
|
student_logits_temp_scaled, dim=-1, keepdim=True
|
||||||
|
) # [chunk_size, 1]
|
||||||
|
student_logprobs_topk_temp_scaled = (
|
||||||
|
student_logits_topk_temp_scaled - student_lse
|
||||||
|
)
|
||||||
|
|
||||||
|
# we have the top-k student logprobs, normalize them
|
||||||
|
if normalize_topk:
|
||||||
|
student_logprobs_topk_temp_scaled = normalize_logprobs(
|
||||||
|
student_logprobs_topk_temp_scaled, topk
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
|
||||||
|
|
||||||
|
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
|
||||||
|
teacher_logprobs_valid = target_logprobs_chunk[valid_mask]
|
||||||
|
|
||||||
|
# Teacher probabilities P(y|x_teacher) from logprobs
|
||||||
|
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
|
||||||
|
teacher_probs_valid = teacher_logprobs_valid.exp()
|
||||||
|
# Student probabilities P_student from log P_student
|
||||||
|
student_probs_topk_valid = student_logprobs_topk_valid.exp()
|
||||||
|
|
||||||
|
# kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
|
||||||
|
|
||||||
|
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
|
||||||
|
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
|
||||||
|
# The distillation loss is often formulated as -sum(P_teacher * log P_student)
|
||||||
|
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
|
||||||
|
# Here, target_logprobs_valid are log_softmax_teacher.
|
||||||
|
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
|
||||||
|
if beta == 0.0: # Contribution from Forward KL
|
||||||
|
fwd_kl_per_token = teacher_probs_valid * (
|
||||||
|
teacher_logprobs_valid - student_logprobs_topk_valid
|
||||||
|
)
|
||||||
|
kd_loss = fwd_kl_per_token.sum()
|
||||||
|
elif beta == 1.0: # Contribution from Reverse KL
|
||||||
|
rev_kl_per_token = student_probs_topk_valid * (
|
||||||
|
student_logprobs_topk_valid - teacher_logprobs_valid
|
||||||
|
)
|
||||||
|
kd_loss = rev_kl_per_token.sum()
|
||||||
|
else:
|
||||||
|
# JSD - Jensen-Shannon Divergence / Symmetric
|
||||||
|
mean_probs = (
|
||||||
|
1 - beta
|
||||||
|
) * student_probs_topk_valid + beta * teacher_probs_valid
|
||||||
|
log_mean_probs = mean_probs.log()
|
||||||
|
student_kl = F.kl_div(
|
||||||
|
log_mean_probs,
|
||||||
|
student_logprobs_topk_valid,
|
||||||
|
reduction="sum",
|
||||||
|
log_target=True,
|
||||||
|
)
|
||||||
|
teacher_kl = F.kl_div(
|
||||||
|
log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True
|
||||||
|
)
|
||||||
|
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
|
||||||
|
kd_loss = jsd_loss
|
||||||
|
|
||||||
|
return kd_loss
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_loss_kl_topk(
|
||||||
|
student_input_chunk: torch.Tensor,
|
||||||
|
student_weight: torch.Tensor,
|
||||||
|
# Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value
|
||||||
|
# or through `partial`. Let's make them explicit here for clarity.
|
||||||
|
target_token_ids_chunk: torch.Tensor,
|
||||||
|
target_logprobs_chunk: torch.Tensor,
|
||||||
|
target_mask_chunk: torch.Tensor,
|
||||||
|
target_chunk: torch.Tensor, # For hard loss (true labels)
|
||||||
|
student_bias: torch.Tensor = None, # This will be one of the grad targets
|
||||||
|
# Other params passed via `partial` from `forward`
|
||||||
|
distillation_loss_fn=None,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
weight_hard_loss: float = 0.5,
|
||||||
|
weight_soft_loss: float = 0.5,
|
||||||
|
compute_ce_loss: bool = True,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
beta: float = 0.0,
|
||||||
|
normalize_topk: bool = True,
|
||||||
|
):
|
||||||
|
# Compute student logits for the chunk from hidden states and LM head
|
||||||
|
# student_input_chunk: [chunk_size, hidden_dim]
|
||||||
|
# student_lm_head_weight: [vocab_size, hidden_dim]
|
||||||
|
# student_logits_chunk: [chunk_size, vocab_size]
|
||||||
|
student_logits_chunk = F.linear(
|
||||||
|
student_input_chunk, student_weight, student_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
ce_loss = torch.tensor(
|
||||||
|
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
|
||||||
|
)
|
||||||
|
if compute_ce_loss and weight_hard_loss > 0.0:
|
||||||
|
ce_loss = F.cross_entropy(
|
||||||
|
student_logits_chunk.view(-1, student_logits_chunk.shape[-1]),
|
||||||
|
target_chunk.view(-1),
|
||||||
|
reduction="sum",
|
||||||
|
ignore_index=ignore_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
soft_loss = torch.tensor(
|
||||||
|
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
|
||||||
|
)
|
||||||
|
if weight_soft_loss > 0.0:
|
||||||
|
student_logits_chunk_temp_scaled = student_logits_chunk / temperature
|
||||||
|
|
||||||
|
# Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max()
|
||||||
|
# No explicit padding here; user must ensure vocab alignment or pre-pad student_weight.
|
||||||
|
|
||||||
|
soft_loss = distillation_loss_fn(
|
||||||
|
student_logits_chunk_temp_scaled,
|
||||||
|
target_token_ids_chunk,
|
||||||
|
target_logprobs_chunk,
|
||||||
|
target_mask_chunk,
|
||||||
|
beta=beta,
|
||||||
|
normalize_topk=normalize_topk,
|
||||||
|
)
|
||||||
|
|
||||||
|
return soft_loss, ce_loss
|
||||||
|
|
||||||
|
@classmethod
def forward(
    cls,
    ctx,
    student_input: torch.Tensor,  # [batch_size, seq_len, dim]
    student_lm_head_weight: torch.Tensor,  # [vocab_size, dim] (F.linear layout)
    target_token_ids: torch.Tensor,  # [batch_size, seq_len, top_k]
    target_logprobs: torch.Tensor,  # [batch_size, seq_len, top_k]
    target_mask: torch.Tensor,  # [batch_size, seq_len, top_k]
    true_labels: torch.Tensor,  # [batch_size, seq_len]
    student_lm_head_bias: torch.Tensor = None,
    weight_hard_loss: float = 0.5,
    weight_soft_loss: float = 0.5,
    ignore_index: int = -100,
    temperature: float = 1.0,
    beta: float = 0.0,
    compiled: bool = False,
    chunk_size: int = 1024,
    compute_ce_loss: bool = True,
    normalize_topk: bool = True,
):
    """
    Chunked fused LM-head + top-k KD loss; gradients are computed eagerly here.

    The flattened inputs are split into row chunks. For every chunk the loss and
    its gradients w.r.t. the hidden states, the LM-head weight and (if given) the
    bias are obtained with ``torch.func.grad_and_value`` and accumulated. The
    gradients are stored on ``ctx`` for ``backward``.

    Returns the scalar
    ``weight_soft_loss * temperature**2 * kd_sum + weight_hard_loss * ce_sum``.

    The function handed to ``grad_and_value`` returns that same weighted
    combination as its differentiated output, with the raw ``(soft, ce)`` pair
    as the auxiliary value. This keeps the saved gradients consistent with the
    returned loss: with ``has_aux=True`` only the first returned element is
    differentiated, so returning ``(soft, ce)`` directly would drop the CE
    gradient and the ``weight_soft_loss * temperature**2`` factor.
    """
    CHUNK_SIZE = chunk_size  # pylint: disable=invalid-name
    grad_weight_acc = torch.zeros_like(student_lm_head_weight)
    grad_inputs_list = []
    grad_bias_acc = (
        torch.zeros_like(student_lm_head_bias)
        if student_lm_head_bias is not None
        else None
    )
    kd_loss_acc = torch.zeros(
        (), device=student_input.device, dtype=student_input.dtype
    )
    ce_loss_acc = torch.zeros(
        (), device=student_input.device, dtype=student_input.dtype
    )

    # Classical KD scaling (T^2) folded together with the soft-loss weight so the
    # differentiated objective matches the loss value returned below.
    soft_scale = weight_soft_loss * (temperature**2)

    # This function is what torch.func.grad_and_value differentiates.
    # The first three arguments are the primals (input chunk, full weight, full bias);
    # the remaining ones are fixed, non-differentiated data for the chunk.
    def loss_fn_for_grad(
        _student_input_chunk,
        _student_lm_head_weight,  # full weight
        _student_lm_head_bias,  # full bias
        _target_token_ids_chunk,
        _target_logprobs_chunk,
        _target_mask_chunk,
        _true_labels_chunk,
    ):
        soft_loss, ce_loss = cls._compute_loss_kl_topk(
            student_input_chunk=_student_input_chunk,
            student_weight=_student_lm_head_weight,
            target_token_ids_chunk=_target_token_ids_chunk,
            target_logprobs_chunk=_target_logprobs_chunk,
            target_mask_chunk=_target_mask_chunk,
            target_chunk=_true_labels_chunk,
            student_bias=_student_lm_head_bias,
            distillation_loss_fn=cls.distillation_loss_fn,
            ignore_index=ignore_index,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            compute_ce_loss=compute_ce_loss,
            temperature=temperature,
            beta=beta,
            normalize_topk=normalize_topk,
        )
        combined = soft_scale * soft_loss + weight_hard_loss * ce_loss
        # (differentiated output, aux) -- aux carries the raw terms for logging/sums
        return combined, (soft_loss, ce_loss)

    def accumulate_chunk_grads(
        student_input_chunk_ac,
        target_token_ids_chunk_ac,
        target_logprobs_chunk_ac,
        target_mask_chunk_ac,
        true_labels_chunk_ac,
    ):
        # The full weight / bias tensors are closed over from the outer scope.
        if student_lm_head_bias is not None:
            (
                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                (_chunk_total, (chunk_kd_loss, chunk_ce_loss)),
            ) = torch.func.grad_and_value(
                loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
            )(
                student_input_chunk_ac,
                student_lm_head_weight,
                student_lm_head_bias,  # primals
                target_token_ids_chunk_ac,
                target_logprobs_chunk_ac,
                target_mask_chunk_ac,
                true_labels_chunk_ac,
            )  # non-primals
            grad_bias_acc.add_(chunk_grad_bias)
        else:
            (
                (chunk_grad_input, chunk_grad_weight),  # No grad for bias
                (_chunk_total, (chunk_kd_loss, chunk_ce_loss)),
            ) = torch.func.grad_and_value(
                loss_fn_for_grad, argnums=(0, 1), has_aux=True
            )(
                student_input_chunk_ac,
                student_lm_head_weight,
                None,  # Pass None for student_bias primal
                target_token_ids_chunk_ac,
                target_logprobs_chunk_ac,
                target_mask_chunk_ac,
                true_labels_chunk_ac,
            )

        grad_weight_acc.add_(chunk_grad_weight)
        kd_loss_acc.add_(chunk_kd_loss)
        ce_loss_acc.add_(chunk_ce_loss)

        return chunk_grad_input

    if compiled:
        accumulate_chunk_grads_compiled = torch.compile(
            accumulate_chunk_grads, dynamic=True, backend="inductor"
        )  # dynamic=True often helpful
    else:
        accumulate_chunk_grads_compiled = accumulate_chunk_grads

    B, N, D = student_input.shape  # pylint: disable=invalid-name
    K = target_token_ids.shape[-1]  # pylint: disable=invalid-name

    student_input_flat = student_input.reshape(-1, student_input.shape[-1])
    target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
    target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
    target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
    # pad and shift for cross entropy loss
    true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
    true_labels_flat = true_labels[:, 1:].contiguous().view(-1)

    num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)

    _student_input_chunks = torch.chunk(
        student_input_flat, chunks=num_chunks, dim=0
    )
    _target_token_ids_chunks = torch.chunk(
        target_token_ids_flat, chunks=num_chunks, dim=0
    )
    _target_logprobs_chunks = torch.chunk(
        target_logprobs_flat, chunks=num_chunks, dim=0
    )
    _target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0)
    _true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0)

    # torch.chunk may return FEWER chunks than requested (e.g. 9 rows split into
    # 4 chunks yields 3), so iterate over what was actually produced instead of
    # indexing with range(num_chunks).
    for chunk_args in zip(
        _student_input_chunks,
        _target_token_ids_chunks,
        _target_logprobs_chunks,
        _target_mask_chunks,
        _true_labels_chunks,
    ):
        grad_inputs_list.append(accumulate_chunk_grads_compiled(*chunk_args))

    grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
    ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)

    # Number of non-tensor hyperparameters that follow student_lm_head_bias in
    # this signature; backward returns one None per entry.
    ctx.hyperparams_count = 9
    ctx.bias_was_none = student_lm_head_bias is None
    ctx.orig_dims = (B, N, D, K)

    # The per-chunk losses are sums, so the total is simply the accumulated sum;
    # the KD term still needs the temperature^2 scaling.
    kd_loss_acc = kd_loss_acc * (temperature**2)
    final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc

    return final_loss
|
||||||
|
|
||||||
|
@staticmethod
def backward(ctx, grad_output):
    """
    Return the gradients that ``forward`` pre-computed, scaled by ``grad_output``.

    The output tuple lines up with ``forward``'s inputs: student_input,
    student_lm_head_weight, four ``None`` entries for the target/label tensors,
    the bias gradient (``None`` when no bias was given), then one ``None`` per
    hyperparameter (``ctx.hyperparams_count`` of them).
    """
    # grad_input_flat is (B*N, D)
    grad_input_flat, grad_weight, grad_bias_maybe = ctx.saved_tensors

    # Skip the multiplications when the upstream gradient is exactly 1.0
    unit = torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype)
    if not torch.equal(grad_output, unit):
        grad_input_flat = grad_input_flat * grad_output
        grad_weight = grad_weight * grad_output
        if grad_bias_maybe is not None:
            grad_bias_maybe = grad_bias_maybe * grad_output

    # ctx.orig_dims stores (B, N, D, K); student_input only needs the first three.
    batch, seq_len, hidden = ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2]

    if grad_input_flat.numel() == 0:
        # An empty flat gradient cannot be viewed safely; hand back zeros of the
        # original shape (covers both the empty-input case and the safeguard
        # case where the original dims were non-empty).
        grad_input_reshaped = torch.zeros(
            batch,
            seq_len,
            hidden,
            dtype=grad_input_flat.dtype,
            device=grad_input_flat.device,
        )
    else:
        grad_input_reshaped = grad_input_flat.view(batch, seq_len, hidden)

    bias_grad = None if ctx.bias_was_none else grad_bias_maybe
    hyperparam_nones = [None] * ctx.hyperparams_count

    return (
        grad_input_reshaped,  # student_input
        grad_weight,  # student_lm_head_weight
        None,  # target_token_ids
        None,  # target_logprobs
        None,  # target_mask
        None,  # true_labels
        bias_grad,  # student_lm_head_bias
        *hyperparam_nones,  # weight_hard_loss, ..., normalize_topk
    )
|
||||||
|
|
||||||
|
|
||||||
|
class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
    """
    Module wrapper around the chunked top-k logprob KL-divergence autograd function.

    The constructor validates and stores the loss hyperparameters; ``forward``
    passes them, together with the tensors, to
    ``LigerFusedLinearKLTopKLogprobFunction.apply``.
    """

    def __init__(
        self,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        temperature: float = 1.0,  # This is the kd_temperature
        beta: float = 1.0,
        ignore_index: int = -100,
        compiled: bool = True,
        chunk_size: int = 1024,
        compute_ce_loss: bool = True,
        normalize_topk: bool = True,
    ):
        """
        Store the hyperparameters.

        Raises:
            ValueError: if either loss weight lies outside [0.0, 1.0] or the
                temperature is not positive.
        """
        super().__init__()

        weights_ok = all(
            0.0 <= weight <= 1.0 for weight in (weight_hard_loss, weight_soft_loss)
        )
        if not weights_ok:
            raise ValueError("Loss weights must be between 0.0 and 1.0.")
        if temperature <= 0:
            raise ValueError("Temperature must be positive.")

        self.weight_hard_loss = weight_hard_loss
        self.weight_soft_loss = weight_soft_loss
        self.temperature = temperature
        self.beta = beta
        self.ignore_index = ignore_index
        self.compiled = compiled
        self.chunk_size = chunk_size
        self.compute_ce_loss = compute_ce_loss
        self.normalize_topk = normalize_topk

        # Configuration sanity warnings; the stored values are left untouched so
        # the caller stays in control.
        hard_weight_unused = self.weight_hard_loss > 0.0 and not self.compute_ce_loss
        if hard_weight_unused:
            print(
                f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero."
            )
        if self.weight_soft_loss == 0.0:
            print(
                "Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed."
            )

    def forward(
        self,
        lm_head_weight: torch.Tensor,  # Weights of the linear layer in the LM head
        student_hidden_states: torch.Tensor,  # student_hidden_states before the lm_head
        target_token_ids: torch.Tensor,
        target_logprobs: torch.Tensor,
        target_mask: torch.Tensor,
        true_labels: torch.Tensor,
        student_bias: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Compute the loss via the fused autograd function.

        Note the argument order: this method takes the LM-head weight first,
        while the autograd function expects the hidden states first.
        """
        tensor_args = (
            student_hidden_states,
            lm_head_weight,
            target_token_ids,
            target_logprobs,
            target_mask,
            true_labels,
            student_bias,
        )
        hyperparams = (
            self.weight_hard_loss,
            self.weight_soft_loss,
            self.ignore_index,
            self.temperature,
            self.beta,
            self.compiled,
            self.chunk_size,
            self.compute_ce_loss,
            self.normalize_topk,
        )
        return LigerFusedLinearKLTopKLogprobFunction.apply(*tensor_args, *hyperparams)
|
||||||
97
src/axolotl/integrations/kd/kernels/models.py
Normal file
97
src/axolotl/integrations/kd/kernels/models.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""
|
||||||
|
model patcher for chunked top-k kl-div
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Union, Unpack
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import Cache
|
||||||
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.utils import LossKwargs
|
||||||
|
|
||||||
|
|
||||||
|
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    """
    Placeholder kwargs type for HF model classes.

    Adds no fields of its own; it only combines ``FlashAttentionKwargs`` and
    ``LossKwargs`` so it can be used as a ``**kwargs`` annotation.
    """
|
||||||
|
|
||||||
|
|
||||||
|
def kldiv_forward_llama_like(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    target_logprobs: Optional[torch.Tensor] = None,
    target_token_ids: Optional[torch.LongTensor] = None,
    target_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,  # pylint: disable=unused-argument
    **kwargs: Unpack[KwargsForCausalLM],  # type: ignore[misc]
) -> CausalLMOutputWithPast:
    """
    Replacement ``forward`` for causal-LM classes that computes the KD loss
    straight from the last hidden state.

    The decoder (``self.model``) is run as usual, but no logits are produced:
    ``self.loss_function`` receives the LM-head weight and the hidden states
    together with the teacher tensors and the labels. If a positive
    ``num_items_in_batch`` is found in ``kwargs`` the loss is divided by it.

    The returned ``CausalLMOutputWithPast`` always has ``logits=None``.
    """
    # pylint: disable=duplicate-code
    # Fall back to the config for flags the caller did not set explicitly.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states

    # kwargs are forwarded to the decoder unchanged (num_items_in_batch, if
    # present, is only popped afterwards).
    decoder_out = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
    # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
    loss = self.loss_function(
        self.lm_head.weight,
        decoder_out.last_hidden_state,
        target_token_ids,
        target_logprobs,
        target_mask,
        true_labels=labels,
    )

    denom = kwargs.pop("num_items_in_batch", -1)
    if denom is not None and denom > 0:
        loss = loss / denom

    return CausalLMOutputWithPast(
        loss=loss,
        logits=None,
        past_key_values=decoder_out.past_key_values,
        hidden_states=decoder_out.hidden_states,
        attentions=decoder_out.attentions,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def apply_kernel(model_type):
    """
    Patch ``forward`` of the ``<ModelType>ForCausalLM`` class in transformers.

    The class name is built by capitalizing each underscore-separated part of
    ``model_type`` and appending ``ForCausalLM``; it is looked up in
    ``transformers.models.<model_type>.modeling_<model_type>`` and its
    ``forward`` is replaced with ``kldiv_forward_llama_like``.
    """
    modeling_path = f"transformers.models.{model_type}.modeling_{model_type}"
    causal_lm_name = (
        "".join(piece.capitalize() for piece in model_type.split("_")) + "ForCausalLM"
    )
    # a non-empty fromlist makes __import__ return the leaf module
    modeling_module = __import__(modeling_path, fromlist=[causal_lm_name])
    causal_lm_cls = getattr(modeling_module, causal_lm_name)
    causal_lm_cls.forward = kldiv_forward_llama_like
|
||||||
@@ -16,40 +16,7 @@
|
|||||||
loss for top_k KL divergence
|
loss for top_k KL divergence
|
||||||
"""
|
"""
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
def zscore_standardize(
|
|
||||||
logits: torch.Tensor,
|
|
||||||
mask: torch.Tensor = None,
|
|
||||||
base_temperature: float = 1.0,
|
|
||||||
eps: float = 1e-9,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Z-score standardize along the last dimension of `logits`.
|
|
||||||
i.e., for each [B, seq_len] row, across K entries:
|
|
||||||
z = (logits - mean) / std,
|
|
||||||
then scale by 1 / base_temperature if desired.
|
|
||||||
|
|
||||||
mask can be broadcastable or None. If None, we standardize all elements.
|
|
||||||
"""
|
|
||||||
if mask is None:
|
|
||||||
# shape: [B, seq_len, K]
|
|
||||||
# Mean and std over dim=-1
|
|
||||||
mean = logits.mean(dim=-1, keepdim=True)
|
|
||||||
var = logits.var(dim=-1, unbiased=False, keepdim=True)
|
|
||||||
else:
|
|
||||||
# If you have to exclude some tokens, multiply by mask, etc.
|
|
||||||
float_mask = mask.to(logits.dtype)
|
|
||||||
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
|
|
||||||
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
|
|
||||||
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
|
|
||||||
|
|
||||||
std = torch.sqrt(var.clamp_min(eps))
|
|
||||||
z = (logits - mean) / std
|
|
||||||
|
|
||||||
# Scale by 1 / base_temperature
|
|
||||||
z = z / base_temperature
|
|
||||||
return z
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
@@ -60,7 +27,6 @@ def loss(
|
|||||||
target_mask: torch.Tensor,
|
target_mask: torch.Tensor,
|
||||||
num_items_in_batch: int = -1, # Use -1 to indicate "None"
|
num_items_in_batch: int = -1, # Use -1 to indicate "None"
|
||||||
kd_temperature: float = 1.0,
|
kd_temperature: float = 1.0,
|
||||||
top_k_before_softmax: int = 0,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
A KD loss function that is TorchScript-friendly.
|
A KD loss function that is TorchScript-friendly.
|
||||||
@@ -77,8 +43,6 @@ def loss(
|
|||||||
num_items_in_batch (int, optional): The number of items in the batch.
|
num_items_in_batch (int, optional): The number of items in the batch.
|
||||||
kd_temperature (float, optional): The temperature for KD.
|
kd_temperature (float, optional): The temperature for KD.
|
||||||
Default: 1.0
|
Default: 1.0
|
||||||
top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits
|
|
||||||
Default: 0
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
target_logprobs = target_logprobs.float()
|
target_logprobs = target_logprobs.float()
|
||||||
@@ -88,46 +52,24 @@ def loss(
|
|||||||
# student_logits shape: [B, student_seq_len, vocab_size]
|
# student_logits shape: [B, student_seq_len, vocab_size]
|
||||||
teacher_seq_len = target_token_ids.shape[1]
|
teacher_seq_len = target_token_ids.shape[1]
|
||||||
|
|
||||||
if top_k_before_softmax:
|
# Slice student logits to match teacher-provided sequence length
|
||||||
# Slice student logits to match teacher-provided sequence length
|
student_logits_for_kd = (
|
||||||
student_logits_for_kd = student_logits[
|
student_logits[:, :teacher_seq_len, :] / kd_temperature
|
||||||
:, :teacher_seq_len, :
|
) # [B, teacher_seq_len, vocab_size]
|
||||||
] # [B, teacher_seq_len, vocab_size]
|
|
||||||
|
|
||||||
# Gather student logits for teacher's top-K tokens
|
# keep in full precision for numerical stability of loss
|
||||||
student_logits_topk = torch.gather(
|
student_logits_for_kd = student_logits_for_kd.float()
|
||||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
|
||||||
) # [B, teacher_seq_len, K]
|
|
||||||
|
|
||||||
student_logits_topk = student_logits_topk.float()
|
# Gather student logits for teacher's top-K tokens
|
||||||
|
student_logits_topk = torch.gather(
|
||||||
|
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||||
|
) # [B, teacher_seq_len, K]
|
||||||
|
|
||||||
# Apply KD temperature to student’s logits
|
# Compute logsumexp across full vocabulary
|
||||||
if kd_temperature != 1.0:
|
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
|
||||||
student_logits_topk = student_logits_topk / kd_temperature
|
|
||||||
|
|
||||||
# Convert student top-k logits to logprobs
|
# Convert just the top-k logits to logprobs
|
||||||
student_logprobs_topk = student_logits_topk - torch.logsumexp(
|
student_logprobs_topk = student_logits_topk - student_lse
|
||||||
student_logits_topk, dim=-1, keepdim=True
|
|
||||||
) # [B, teacher_seq_len, K]
|
|
||||||
else:
|
|
||||||
# Slice student logits to match teacher-provided sequence length
|
|
||||||
student_logits_for_kd = (
|
|
||||||
student_logits[:, :teacher_seq_len, :] / kd_temperature
|
|
||||||
) # [B, teacher_seq_len, vocab_size]
|
|
||||||
|
|
||||||
# keep in full precision for numerical stability of loss
|
|
||||||
student_logits_for_kd = student_logits_for_kd.float()
|
|
||||||
|
|
||||||
# Gather student logits for teacher's top-K tokens
|
|
||||||
student_logits_topk = torch.gather(
|
|
||||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
|
||||||
) # [B, teacher_seq_len, K]
|
|
||||||
|
|
||||||
# Compute logsumexp across full vocabulary
|
|
||||||
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
# Convert just the top-k logits to logprobs
|
|
||||||
student_logprobs_topk = student_logits_topk - student_lse
|
|
||||||
|
|
||||||
# Convert teacher_mask to boolean for indexing
|
# Convert teacher_mask to boolean for indexing
|
||||||
# In TorchScript, .bool() is sometimes unsupported, so we do:
|
# In TorchScript, .bool() is sometimes unsupported, so we do:
|
||||||
@@ -144,10 +86,6 @@ def loss(
|
|||||||
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
|
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
|
||||||
kd_loss = kd_loss_per_token.sum()
|
kd_loss = kd_loss_per_token.sum()
|
||||||
|
|
||||||
# Multiply by T^2 (classical KD scaling)
|
|
||||||
if kd_temperature != 1.0:
|
|
||||||
kd_loss = kd_loss * (kd_temperature**2)
|
|
||||||
|
|
||||||
# Normalize by number of items (if provided) or by valid tokens
|
# Normalize by number of items (if provided) or by valid tokens
|
||||||
if num_items_in_batch > 0:
|
if num_items_in_batch > 0:
|
||||||
kd_loss = kd_loss / float(num_items_in_batch)
|
kd_loss = kd_loss / float(num_items_in_batch)
|
||||||
@@ -158,80 +96,74 @@ def loss(
|
|||||||
return kd_loss
|
return kd_loss
|
||||||
|
|
||||||
|
|
||||||
def topk_kd_loss_with_zscore(
|
class ChunkedTopKKDLoss(nn.Module):
|
||||||
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
|
||||||
target_token_ids: torch.Tensor, # [B, seq_len, K]
|
|
||||||
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
|
|
||||||
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
|
|
||||||
kd_temperature: float = 1.0, # classic KD temperature
|
|
||||||
zscore_base_temp: float = 1.0, # from the paper
|
|
||||||
num_items_in_batch: int = -1,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
A variant of top_k KL divergence with Z-score scaling
|
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
|
||||||
from "Logit Standardization in Knowledge Distillation".
|
to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
|
||||||
|
|
||||||
|
Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
target_logprobs = target_logprobs.float()
|
def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
|
||||||
|
super().__init__()
|
||||||
|
self.num_output_chunks = num_output_chunks
|
||||||
|
self.kd_temperature = kd_temperature
|
||||||
|
|
||||||
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
|
def forward(
|
||||||
# 1) Gather the student's top-k logits to match teacher
|
self,
|
||||||
student_logits_for_kd = student_logits[
|
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
||||||
:, :teacher_seq_len, :
|
target_token_ids: torch.Tensor, # [B, seq_len, K]
|
||||||
] # [B, seq_len, vocab]
|
target_logprobs: torch.Tensor, # [B, seq_len, K]
|
||||||
student_topk_logits = torch.gather(
|
target_mask: torch.Tensor, # [B, seq_len, K]
|
||||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
num_items_in_batch: int = -1, # optional batch size for normalization
|
||||||
) # [B, seq_len, K]
|
) -> torch.Tensor:
|
||||||
|
|
||||||
student_topk_logits = student_topk_logits.float()
|
# 1. Split along the "token" dimension (dim=1).
|
||||||
|
student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1)
|
||||||
|
token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1)
|
||||||
|
logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1)
|
||||||
|
mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1)
|
||||||
|
|
||||||
# 2) If you want to keep the "classical" T scaling, apply it first
|
# We'll accumulate a global "sum of losses" and "sum of valid tokens"
|
||||||
if kd_temperature != 1.0:
|
# so that our final average is consistent with the entire sequence/batch.
|
||||||
student_topk_logits = student_topk_logits / kd_temperature
|
total_loss = 0.0
|
||||||
|
total_valid_tokens = 0
|
||||||
|
|
||||||
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
|
# 2. Loop over each chunk and compute a chunk-specific loss.
|
||||||
# (They differ by +some_constant from real logits, but in z-score
|
for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip(
|
||||||
# that constant is subtracted out anyway.)
|
student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks
|
||||||
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
|
):
|
||||||
|
# We pass num_items_in_batch=-1 so that the kd_loss
|
||||||
|
# will average over *this chunk's* valid tokens only.
|
||||||
|
chunk_loss = loss(
|
||||||
|
student_logits=st_chunk,
|
||||||
|
target_token_ids=tid_chunk,
|
||||||
|
target_logprobs=lp_chunk,
|
||||||
|
target_mask=msk_chunk,
|
||||||
|
num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens
|
||||||
|
kd_temperature=self.kd_temperature,
|
||||||
|
)
|
||||||
|
|
||||||
# 4) Z-score teacher and student
|
# kd_loss returns an average over the chunk's valid tokens.
|
||||||
# If target_mask is 2D, expand to 3D for the K dimension
|
# We want a global average in the end, so we need to re‐weight
|
||||||
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
|
# by the number of valid tokens in this chunk and keep track of the total.
|
||||||
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
|
chunk_valid_mask = msk_chunk.to(torch.bool)
|
||||||
|
chunk_valid_count = chunk_valid_mask.sum() # scalar tensor
|
||||||
|
|
||||||
teacher_z = zscore_standardize(
|
# Re-scale "chunk average" back to "chunk sum"
|
||||||
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
|
chunk_loss_sum = chunk_loss * chunk_valid_count
|
||||||
)
|
|
||||||
student_z = zscore_standardize(
|
|
||||||
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5) Convert to log-probs for KL
|
total_loss += chunk_loss_sum
|
||||||
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
|
total_valid_tokens += chunk_valid_count
|
||||||
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
# 6) Restrict to valid tokens if needed
|
# 3. Normalize *once* at the end.
|
||||||
valid_mask = target_mask.bool() # shape [B, seq_len, K]
|
if num_items_in_batch > 0:
|
||||||
teacher_probs_z = teacher_logprobs_z.exp()
|
# If the user gave us a manual denominator (e.g. total items in batch),
|
||||||
teacher_probs_z = teacher_probs_z[valid_mask]
|
# we divide by it. Typically used if each item is of different length.
|
||||||
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
|
final_loss = total_loss / float(num_items_in_batch)
|
||||||
student_logprobs_z = student_logprobs_z[valid_mask]
|
else:
|
||||||
|
# Otherwise, divide by total valid tokens across all chunks.
|
||||||
|
# to get the same result as a non-chunked approach.
|
||||||
|
final_loss = total_loss / float(total_valid_tokens)
|
||||||
|
|
||||||
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
|
return final_loss
|
||||||
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
|
|
||||||
kd_loss = kd_loss_per_token.sum()
|
|
||||||
|
|
||||||
# 8) If using classical KD scaling by T^2
|
|
||||||
if kd_temperature != 1.0:
|
|
||||||
kd_loss = kd_loss * (kd_temperature**2)
|
|
||||||
|
|
||||||
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
|
|
||||||
# kd_loss = kd_loss * (zscore_base_temp**2)
|
|
||||||
|
|
||||||
# 9) Normalize
|
|
||||||
if num_items_in_batch is not None and num_items_in_batch > 0:
|
|
||||||
kd_loss = kd_loss / float(num_items_in_batch)
|
|
||||||
else:
|
|
||||||
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
|
|
||||||
|
|
||||||
return kd_loss
|
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ KD trainer
|
|||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
||||||
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKDTrainer(AxolotlTrainer):
|
class AxolotlKDTrainer(AxolotlTrainer):
|
||||||
@@ -27,6 +26,18 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
Custom trainer subclass for Knowledge Distillation (KD)
|
Custom trainer subclass for Knowledge Distillation (KD)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.model_accepts_loss_kwargs = True
|
||||||
|
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
|
||||||
|
self.args.kd_ce_alpha, # hard label loss
|
||||||
|
self.args.kd_alpha, # kd loss
|
||||||
|
self.args.kd_temperature,
|
||||||
|
self.args.kd_beta,
|
||||||
|
compute_ce_loss=bool(self.args.kd_ce_alpha),
|
||||||
|
normalize_topk=self.args.kd_normalize_topk,
|
||||||
|
)
|
||||||
|
|
||||||
def _set_signature_columns_if_needed(self):
|
def _set_signature_columns_if_needed(self):
|
||||||
super()._set_signature_columns_if_needed()
|
super()._set_signature_columns_if_needed()
|
||||||
columns_to_add = []
|
columns_to_add = []
|
||||||
@@ -52,12 +63,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
|
|
||||||
Subclass and override for custom behavior.
|
Subclass and override for custom behavior.
|
||||||
"""
|
"""
|
||||||
|
if (
|
||||||
target_logprobs = inputs.pop("target_logprobs")
|
self.args.sample_packing
|
||||||
target_token_ids = inputs.pop("target_token_ids")
|
and hasattr(inputs, "attention_mask")
|
||||||
target_mask = inputs.pop("target_mask")
|
and hasattr(inputs, "position_ids")
|
||||||
|
):
|
||||||
seq_len = target_token_ids.shape[1]
|
del inputs["attention_mask"]
|
||||||
|
|
||||||
if self.model_accepts_loss_kwargs:
|
if self.model_accepts_loss_kwargs:
|
||||||
loss_kwargs = {}
|
loss_kwargs = {}
|
||||||
@@ -65,49 +76,4 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||||
inputs = {**inputs, **loss_kwargs}
|
inputs = {**inputs, **loss_kwargs}
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
return outputs[0]
|
||||||
# FIXME: account for tokenizer.padding_side
|
|
||||||
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
|
|
||||||
|
|
||||||
shift_logits = student_logits.contiguous()
|
|
||||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
|
||||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
|
||||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
|
||||||
|
|
||||||
if self.args.kd_zscore_base_temp:
|
|
||||||
loss_kd = topk_kd_loss_with_zscore(
|
|
||||||
shift_logits,
|
|
||||||
target_token_ids_for_loss,
|
|
||||||
target_logprobs_for_loss,
|
|
||||||
target_mask_for_loss,
|
|
||||||
kd_temperature=self.args.kd_temperature,
|
|
||||||
zscore_base_temp=self.args.kd_zscore_base_temp,
|
|
||||||
num_items_in_batch=num_items_in_batch,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
loss_kd = topk_kd_loss(
|
|
||||||
shift_logits,
|
|
||||||
target_token_ids_for_loss,
|
|
||||||
target_logprobs_for_loss,
|
|
||||||
target_mask_for_loss,
|
|
||||||
num_items_in_batch=num_items_in_batch,
|
|
||||||
kd_temperature=self.args.kd_temperature,
|
|
||||||
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.kd_ce_alpha > 0:
|
|
||||||
kd_alpha = self.args.kd_alpha
|
|
||||||
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
|
||||||
else:
|
|
||||||
loss = loss_kd
|
|
||||||
# Save past state if it exists
|
|
||||||
# TODO: this needs to be fixed and made cleaner later.
|
|
||||||
if self.args.past_index >= 0:
|
|
||||||
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.args.past_index
|
|
||||||
]
|
|
||||||
|
|
||||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
|
||||||
loss *= self.accelerator.num_processes
|
|
||||||
|
|
||||||
return (loss, outputs) if return_outputs else loss
|
|
||||||
|
|||||||
100
src/axolotl/integrations/kd/utils.py
Normal file
100
src/axolotl/integrations/kd/utils.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""Helper KD utils"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import FloatTensor, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
    """
    Re-normalizes top-k raw logprobs so they form a full distribution, in log space.

    Args:
        logprobs: Raw log-probabilities; the last dimension holds the top-k entries.
        topk: Expected size of the last dimension. Narrower inputs are right-padded
            with ``-inf`` (zero probability) up to this width.

    Returns:
        Log-probabilities whose exponentials sum to 1 along the last dimension.
        Padded positions stay at ``-inf``.

    Raises:
        RuntimeError: If the last dimension is wider than ``topk`` (the padding
            tensor would need a negative size). Callers are expected to truncate first.
    """
    if logprobs.shape[-1] != topk:
        # pad last dimension of logprobs to match topk length with -inf
        padding_len = topk - logprobs.shape[-1]
        padding_tensor = torch.full(
            (*logprobs.shape[:-1], padding_len),
            float("-inf"),
            dtype=logprobs.dtype,
            device=logprobs.device,
        )
        logprobs = torch.cat((logprobs, padding_tensor), dim=-1)

    # Normalize directly in log space (x - logsumexp(x)). Going through
    # exp() -> divide -> log() underflows to -inf for very unlikely tokens, and
    # the division was a no-op because exp(x - logsumexp(x)) already sums to 1.
    return torch.log_softmax(logprobs, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def strided_chunk_views(
|
||||||
|
tensor: Union[np.ndarray, torch.Tensor],
|
||||||
|
chunks: int,
|
||||||
|
dim: int = 0,
|
||||||
|
stride: int = 1,
|
||||||
|
chunk_size: int | None = None,
|
||||||
|
) -> List[Union[np.ndarray, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Split a tensor into chunks along a dimension with striding, prioritizing views over copies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: Input tensor (numpy array or torch tensor)
|
||||||
|
chunks: Number of chunks to create
|
||||||
|
dim: Dimension along which to chunk (default: 0)
|
||||||
|
stride: Stride between chunk starting positions (default: 1)
|
||||||
|
chunk_size: Size of each chunk. If None, calculated automatically (default: None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tensor chunks (views when possible, copies when necessary)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get the size of the specified dimension
|
||||||
|
dim_size = tensor.shape[dim]
|
||||||
|
|
||||||
|
# Calculate chunk size if not provided
|
||||||
|
if chunk_size is None:
|
||||||
|
chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division
|
||||||
|
|
||||||
|
chunks_list = []
|
||||||
|
|
||||||
|
for i in range(chunks):
|
||||||
|
start_idx = i * stride
|
||||||
|
end_idx = min(start_idx + chunk_size, dim_size)
|
||||||
|
|
||||||
|
# Break if we've gone beyond the tensor
|
||||||
|
if start_idx >= dim_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Create slice objects for all dimensions
|
||||||
|
slices = [slice(None)] * tensor.ndim
|
||||||
|
slices[dim] = slice(start_idx, end_idx)
|
||||||
|
|
||||||
|
chunk = tensor[tuple(slices)]
|
||||||
|
chunks_list.append(chunk)
|
||||||
|
|
||||||
|
return chunks_list
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):
    """
    Split ``input_tensor`` along ``dim`` into windows that overlap their successor.

    Window starts are spaced ``ceil(dim_size / chunks)`` apart, and each window is
    ``overlap`` elements longer than that spacing.
    """
    step = math.ceil(input_tensor.shape[dim] / chunks)
    window = step + overlap

    return strided_chunk_views(input_tensor, chunks, dim, stride=step, chunk_size=window)
|
||||||
@@ -2,16 +2,8 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
import transformers
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
|
|
||||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
|
||||||
from mistral_common.tokens.tokenizers.mistral import (
|
|
||||||
MistralTokenizer,
|
|
||||||
)
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AddedToken,
|
AddedToken,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -31,622 +23,239 @@ from axolotl.utils.logging import get_logger
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
PLUGIN_MANAGER = PluginManager.get_instance()
|
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||||
|
|
||||||
# Constants
|
|
||||||
LLAMA_TOKENIZER_CLASSES = {
|
|
||||||
"LlamaTokenizer",
|
|
||||||
"LlamaTokenizerFast",
|
|
||||||
"CodeLlamaTokenizer",
|
|
||||||
"CodeLlamaTokenizerFast",
|
|
||||||
}
|
|
||||||
|
|
||||||
FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
|
def modify_tokenizer_files(
|
||||||
|
tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
|
||||||
QWEN_DEFAULT_TOKEN = "<|endoftext|>"
|
) -> str:
|
||||||
GPTNEOX_PAD_TOKEN = "[PAD]"
|
|
||||||
CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
|
|
||||||
|
|
||||||
|
|
||||||
class MistralTokenizerWrapper:
|
|
||||||
"""
|
"""
|
||||||
Wrapper to make MistralTokenizer compatible with Hugging Face tokenizer interface.
|
Modify tokenizer files to replace added_tokens strings, save to output directory,
|
||||||
This provides a bridge between Mistral's native tokenizer and axolotl's expectations.
|
and return the path to the modified tokenizer.
|
||||||
|
|
||||||
|
This only works with reserved tokens that were added to the tokenizer, not tokens
|
||||||
|
already part of the vocab.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer_path: Path or name of the original tokenizer
|
||||||
|
token_mappings: Dict mapping {token_id (int): new_token_string}
|
||||||
|
output_dir: Directory to save the modified tokenizer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the modified tokenizer directory
|
||||||
|
|
||||||
|
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
|
||||||
"""
|
"""
|
||||||
|
# Create the tokenizer directory in output_dir if it doesn't exist
|
||||||
|
tokenizer_dir = os.path.join(output_dir, "tokenizer")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
|
||||||
def __init__(self, mistral_tokenizer: "MistralTokenizer", model_id: str):
|
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
|
||||||
self.mistral_tokenizer = mistral_tokenizer
|
# Load the tokenizer
|
||||||
self.model_id = model_id
|
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
|
||||||
self._system_prompt = None
|
|
||||||
self.padding_side = "right" # Default padding side
|
|
||||||
self.chat_template = None
|
|
||||||
|
|
||||||
# Cache token IDs by inspecting the actual tokenizer
|
# Save the tokenizer to the output directory
|
||||||
self._token_ids = self._discover_token_ids()
|
temp_tokenizer.save_pretrained(tokenizer_dir)
|
||||||
|
|
||||||
# Try to load system prompt if available
|
# Get the token IDs and map them to their new values
|
||||||
try:
|
|
||||||
self._system_prompt = self._load_system_prompt(
|
|
||||||
model_id, "SYSTEM_PROMPT.txt"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
LOG.debug(f"Could not load system prompt: {e}")
|
|
||||||
|
|
||||||
def _discover_token_ids(self) -> Dict[str, int]:
|
|
||||||
"""Discover the actual token IDs used by this Mistral tokenizer."""
|
|
||||||
token_ids = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
if hasattr(self.mistral_tokenizer, "instruct_tokenizer"):
|
|
||||||
instruct_tokenizer = self.mistral_tokenizer.instruct_tokenizer
|
|
||||||
|
|
||||||
# Get BOS token ID from instruct_tokenizer
|
|
||||||
token_ids["bos_token_id"] = getattr(instruct_tokenizer, "BOS", 1)
|
|
||||||
|
|
||||||
# Get token IDs from the underlying Tekkenizer
|
|
||||||
if hasattr(instruct_tokenizer, "tokenizer"):
|
|
||||||
tekkenizer = instruct_tokenizer.tokenizer
|
|
||||||
|
|
||||||
# Get BOS ID from tekkenizer (should match instruct_tokenizer.BOS)
|
|
||||||
if hasattr(tekkenizer, "bos_id"):
|
|
||||||
token_ids["bos_token_id"] = tekkenizer.bos_id
|
|
||||||
|
|
||||||
# Get vocab size to help find EOS token
|
|
||||||
vocab_size = getattr(tekkenizer, "_vocab_size", None)
|
|
||||||
|
|
||||||
# Check special tokens
|
|
||||||
if hasattr(tekkenizer, "_all_special_tokens"):
|
|
||||||
special_tokens = tekkenizer._all_special_tokens
|
|
||||||
keys = (
|
|
||||||
list(special_tokens.keys())
|
|
||||||
if hasattr(special_tokens, "keys")
|
|
||||||
else special_tokens
|
|
||||||
)
|
|
||||||
LOG.debug(f"Special tokens available: {keys}")
|
|
||||||
|
|
||||||
# Try to find EOS token in special tokens
|
|
||||||
if hasattr(special_tokens, "get"):
|
|
||||||
# Common EOS token patterns
|
|
||||||
for eos_pattern in ["</s>", "<|endoftext|>", "eos", "EOS"]:
|
|
||||||
if eos_pattern in special_tokens:
|
|
||||||
token_ids["eos_token_id"] = special_tokens[
|
|
||||||
eos_pattern
|
|
||||||
]
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check special tokens reverse vocab
|
|
||||||
if hasattr(tekkenizer, "_special_tokens_reverse_vocab"):
|
|
||||||
reverse_vocab = tekkenizer._special_tokens_reverse_vocab
|
|
||||||
LOG.debug(f"Reverse special tokens: {reverse_vocab}")
|
|
||||||
|
|
||||||
# Look for common special token IDs
|
|
||||||
for token_id, token_str in reverse_vocab.items():
|
|
||||||
if token_str in ["</s>", "<|endoftext|>"]:
|
|
||||||
token_ids["eos_token_id"] = token_id
|
|
||||||
elif token_str in ["<unk>", "<UNK>"]:
|
|
||||||
token_ids["unk_token_id"] = token_id
|
|
||||||
|
|
||||||
# If we have vocab_size, EOS is often vocab_size - 1 or similar
|
|
||||||
if "eos_token_id" not in token_ids and vocab_size:
|
|
||||||
# Common patterns: EOS could be 2, vocab_size-1, or other values
|
|
||||||
# Let's try a safer approach by checking what tokens decode to
|
|
||||||
for candidate_id in [2, vocab_size - 1, vocab_size - 2]:
|
|
||||||
try:
|
|
||||||
# Try to decode and see if it looks like EOS
|
|
||||||
decoded = tekkenizer.decode([candidate_id])
|
|
||||||
if decoded in ["</s>", "<|endoftext|>", ""]:
|
|
||||||
token_ids["eos_token_id"] = candidate_id
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
LOG.debug(f"Could not discover token IDs: {e}")
|
|
||||||
|
|
||||||
# Set reasonable defaults for any missing token IDs
|
|
||||||
token_ids.setdefault("bos_token_id", 1)
|
|
||||||
token_ids.setdefault("eos_token_id", 2)
|
|
||||||
token_ids.setdefault("unk_token_id", 0)
|
|
||||||
token_ids.setdefault(
|
|
||||||
"pad_token_id", token_ids["eos_token_id"]
|
|
||||||
) # Use EOS as pad
|
|
||||||
|
|
||||||
LOG.info(f"Discovered Mistral token IDs: {token_ids}")
|
|
||||||
return token_ids
|
|
||||||
|
|
||||||
def _load_system_prompt(self, repo_id: str, filename: str) -> str:
    """
    Load system prompt from HuggingFace Hub.

    Args:
        repo_id: Repository passed to ``hf_hub_download``
        filename: File inside that repository to fetch

    Returns:
        The full text content of the fetched file
    """
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # Decode explicitly as UTF-8 so the result does not depend on the platform's
    # locale encoding (matches the other open() calls in this module).
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()
|
|
||||||
|
|
||||||
def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]:
    """
    Encode text to token IDs.

    The string is wrapped as a single user message; the stored system prompt is
    prepended only when one exists and ``add_special_tokens`` is true. Extra
    keyword arguments are accepted but not used.

    Raises:
        ValueError: If ``text`` is not a ``str``.
    """
    if not isinstance(text, str):
        raise ValueError("MistralTokenizer wrapper only supports string input")

    # Optional system turn first, then the user turn carrying the text
    leading = (
        [SystemMessage(content=self._system_prompt)]
        if self._system_prompt and add_special_tokens
        else []
    )
    conversation = [*leading, UserMessage(content=text)]

    request = ChatCompletionRequest(messages=conversation)
    return self.mistral_tokenizer.encode_chat_completion(request).tokens
|
|
||||||
|
|
||||||
def decode(
    self,
    token_ids: Union[List[int], torch.Tensor],
    skip_special_tokens: bool = True,
) -> str:
    """
    Decode token IDs to text.

    Tensors are converted to plain Python lists before being handed to the
    wrapped tokenizer. ``skip_special_tokens`` is accepted but not forwarded.
    """
    plain_ids = token_ids.tolist() if isinstance(token_ids, torch.Tensor) else token_ids
    return self.mistral_tokenizer.decode(plain_ids)
|
|
||||||
|
|
||||||
def __call__(self, text: str, **kwargs):
|
|
||||||
"""Make the tokenizer callable like HF tokenizers"""
|
|
||||||
tokens = self.encode(text, **kwargs)
|
|
||||||
return {"input_ids": torch.tensor([tokens])}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def eos_token_id(self):
|
|
||||||
return self._token_ids["eos_token_id"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bos_token_id(self):
|
|
||||||
return self._token_ids["bos_token_id"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pad_token_id(self):
|
|
||||||
return self._token_ids["pad_token_id"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def unk_token_id(self):
|
|
||||||
return self._token_ids["unk_token_id"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def eos_token(self):
|
|
||||||
return "</s>" # Standard Mistral EOS token
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bos_token(self):
|
|
||||||
return "<s>" # Standard Mistral BOS token
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pad_token(self):
|
|
||||||
return self.eos_token # Use EOS as pad token
|
|
||||||
|
|
||||||
@property
|
|
||||||
def unk_token(self):
|
|
||||||
return "<unk>" # Standard UNK token
|
|
||||||
|
|
||||||
@property
|
|
||||||
def __class__(self):
|
|
||||||
# Create a mock class for compatibility checks
|
|
||||||
class MistralTokenizerWrapperClass:
|
|
||||||
__name__ = "MistralTokenizerWrapper"
|
|
||||||
|
|
||||||
return MistralTokenizerWrapperClass
|
|
||||||
|
|
||||||
def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int:
|
|
||||||
"""Placeholder for special token addition - Mistral tokenizer handles this internally"""
|
|
||||||
LOG.warning(
|
|
||||||
"add_special_tokens called on MistralTokenizer wrapper - this is handled internally"
|
|
||||||
)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def add_tokens(self, tokens) -> int:
|
|
||||||
"""Placeholder for token addition - Mistral tokenizer handles this internally"""
|
|
||||||
LOG.warning(
|
|
||||||
"add_tokens called on MistralTokenizer wrapper - this is handled internally"
|
|
||||||
)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
class TokenizerFileModifier:
|
|
||||||
"""Handles modification of tokenizer files for token overrides."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
|
|
||||||
):
|
|
||||||
self.tokenizer_path = tokenizer_path
|
|
||||||
self.token_mappings = token_mappings
|
|
||||||
self.output_dir = output_dir
|
|
||||||
self.tokenizer_dir = os.path.join(output_dir, "tokenizer")
|
|
||||||
|
|
||||||
def modify_and_save(self) -> str:
|
|
||||||
"""Modify tokenizer files and return path to modified tokenizer."""
|
|
||||||
os.makedirs(self.tokenizer_dir, exist_ok=True)
|
|
||||||
|
|
||||||
if is_local_main_process():
|
|
||||||
self._perform_modifications()
|
|
||||||
barrier()
|
|
||||||
|
|
||||||
return self.tokenizer_dir
|
|
||||||
|
|
||||||
def _perform_modifications(self):
|
|
||||||
"""Perform the actual file modifications."""
|
|
||||||
# Load and save tokenizer to output directory
|
|
||||||
temp_tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
self.tokenizer_path, use_fast=True
|
|
||||||
)
|
|
||||||
temp_tokenizer.save_pretrained(self.tokenizer_dir)
|
|
||||||
|
|
||||||
# Convert token mappings to proper format
|
|
||||||
token_id_mappings = {
|
token_id_mappings = {
|
||||||
int(token_id): new_value
|
int(token_id): new_value for token_id, new_value in token_mappings.items()
|
||||||
for token_id, new_value in self.token_mappings.items()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Update both tokenizer files
|
# 1. Update tokenizer_config.json - added_tokens_decoder
|
||||||
self._update_tokenizer_config(token_id_mappings)
|
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
|
||||||
self._update_tokenizer_json(token_id_mappings)
|
if os.path.exists(config_path):
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config_data = json.load(f)
|
||||||
|
|
||||||
def _update_tokenizer_config(self, token_id_mappings: Dict[int, str]):
|
# Update added_tokens_decoder
|
||||||
"""Update tokenizer_config.json with new token mappings."""
|
if "added_tokens_decoder" in config_data:
|
||||||
config_path = os.path.join(self.tokenizer_dir, "tokenizer_config.json")
|
for token_id, new_value in token_id_mappings.items():
|
||||||
if not os.path.exists(config_path):
|
token_id_str = str(token_id)
|
||||||
return
|
if token_id_str in config_data["added_tokens_decoder"]:
|
||||||
|
config_data["added_tokens_decoder"][token_id_str][
|
||||||
|
"content"
|
||||||
|
] = new_value
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Token ID {token_id_str} not found in added_tokens_decoder"
|
||||||
|
)
|
||||||
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
# Write the updated config back
|
||||||
config_data = json.load(f)
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(config_data, f, indent=2)
|
||||||
|
|
||||||
if "added_tokens_decoder" in config_data:
|
# 2. Update tokenizer.json - added_tokens
|
||||||
self._update_added_tokens_decoder(config_data, token_id_mappings)
|
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
|
||||||
|
if os.path.exists(tokenizer_path):
|
||||||
|
with open(tokenizer_path, "r", encoding="utf-8") as f:
|
||||||
|
tokenizer_data = json.load(f)
|
||||||
|
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
# Update added_tokens
|
||||||
json.dump(config_data, f, indent=2)
|
if "added_tokens" in tokenizer_data:
|
||||||
|
for token_id, new_value in token_id_mappings.items():
|
||||||
|
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
|
||||||
|
if token_entry["id"] == token_id:
|
||||||
|
tokenizer_data["added_tokens"][i]["content"] = new_value
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
|
||||||
|
raise ValueError(
|
||||||
|
f"Token ID {token_id} not found in added_tokens"
|
||||||
|
)
|
||||||
|
if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]:
|
||||||
|
for token_id, new_value in token_id_mappings.items():
|
||||||
|
for entry_val, entry_id in tokenizer_data["model"]["vocab"].items():
|
||||||
|
if entry_id == token_id:
|
||||||
|
del tokenizer_data["model"]["vocab"][entry_val]
|
||||||
|
tokenizer_data["model"]["vocab"][new_value] = token_id
|
||||||
|
break
|
||||||
|
|
||||||
def _update_added_tokens_decoder(
|
# Write the updated tokenizer data back
|
||||||
self, config_data: Dict, token_id_mappings: Dict[int, str]
|
with open(tokenizer_path, "w", encoding="utf-8") as f:
|
||||||
):
|
json.dump(tokenizer_data, f, indent=2)
|
||||||
"""Update the added_tokens_decoder section."""
|
|
||||||
for token_id, new_value in token_id_mappings.items():
|
|
||||||
token_id_str = str(token_id)
|
|
||||||
if token_id_str in config_data["added_tokens_decoder"]:
|
|
||||||
config_data["added_tokens_decoder"][token_id_str]["content"] = new_value
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Token ID {token_id_str} not found in added_tokens_decoder"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_tokenizer_json(self, token_id_mappings: Dict[int, str]):
|
barrier()
|
||||||
"""Update tokenizer.json with new token mappings."""
|
return tokenizer_dir
|
||||||
tokenizer_json_path = os.path.join(self.tokenizer_dir, "tokenizer.json")
|
|
||||||
if not os.path.exists(tokenizer_json_path):
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(tokenizer_json_path, "r", encoding="utf-8") as f:
|
|
||||||
tokenizer_data = json.load(f)
|
|
||||||
|
|
||||||
self._update_added_tokens_list(tokenizer_data, token_id_mappings)
|
|
||||||
self._update_vocab_mappings(tokenizer_data, token_id_mappings)
|
|
||||||
|
|
||||||
with open(tokenizer_json_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(tokenizer_data, f, indent=2)
|
|
||||||
|
|
||||||
def _update_added_tokens_list(
|
|
||||||
self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
|
|
||||||
):
|
|
||||||
"""Update the added_tokens list in tokenizer.json."""
|
|
||||||
if "added_tokens" not in tokenizer_data:
|
|
||||||
return
|
|
||||||
|
|
||||||
for token_id, new_value in token_id_mappings.items():
|
|
||||||
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
|
|
||||||
if token_entry["id"] == token_id:
|
|
||||||
tokenizer_data["added_tokens"][i]["content"] = new_value
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Token ID {token_id} not found in added_tokens")
|
|
||||||
|
|
||||||
def _update_vocab_mappings(
|
|
||||||
self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
|
|
||||||
):
|
|
||||||
"""Update vocab mappings in tokenizer.json."""
|
|
||||||
if not (tokenizer_data.get("model") and tokenizer_data["model"].get("vocab")):
|
|
||||||
return
|
|
||||||
|
|
||||||
vocab = tokenizer_data["model"]["vocab"]
|
|
||||||
for token_id, new_value in token_id_mappings.items():
|
|
||||||
# Find and update the vocab entry
|
|
||||||
for entry_val, entry_id in list(vocab.items()):
|
|
||||||
if entry_id == token_id:
|
|
||||||
del vocab[entry_val]
|
|
||||||
vocab[new_value] = token_id
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
class TokenizerConfiguration:
    """Resolves, from the config object, how the tokenizer should be loaded."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.model_config = load_model_config(cfg)

    def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
        """Build a wrapped Mistral tokenizer for ``cfg.base_model``."""
        model_id = self.cfg.base_model

        # Native tokenizer first, then the compatibility wrapper around it
        native_tokenizer = MistralTokenizer.from_hf_hub(model_id)
        wrapped = MistralTokenizerWrapper(native_tokenizer, model_id)
        LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")

        return wrapped

    def get_tokenizer_class(self):
        """Return the class named by ``cfg.tokenizer_type``, else ``AutoTokenizer``."""
        if not self.cfg.tokenizer_type:
            return AutoTokenizer
        return getattr(transformers, self.cfg.tokenizer_type)

    def get_tokenizer_kwargs(self) -> Dict[str, Any]:
        """Return init kwargs; ``legacy`` is included only when configured."""
        legacy = self.cfg.tokenizer_legacy
        return {} if legacy is None else {"legacy": legacy}

    def get_tokenizer_path(self) -> str:
        """
        Return the tokenizer path to load from.

        When ``cfg.added_tokens_overrides`` is set, the tokenizer files are
        rewritten under ``cfg.output_dir`` and that modified path is returned.
        """
        configured_path = self.cfg.tokenizer_config
        if not self.cfg.added_tokens_overrides:
            return configured_path

        modifier = TokenizerFileModifier(
            configured_path, self.cfg.added_tokens_overrides, self.cfg.output_dir
        )
        return modifier.modify_and_save()

    def should_use_fast_tokenizer(self) -> bool:
        """Return ``cfg.tokenizer_use_fast``, defaulting to True when it is unset."""
        if self.cfg.tokenizer_use_fast is None:
            return True
        return self.cfg.tokenizer_use_fast
|
|
||||||
|
|
||||||
|
|
||||||
class TokenizerPostProcessor:
|
|
||||||
"""Handles post-processing configuration of loaded tokenizers."""
|
|
||||||
|
|
||||||
def __init__(self, tokenizer, cfg):
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.cfg = cfg
|
|
||||||
self.model_config = load_model_config(cfg)
|
|
||||||
|
|
||||||
def apply_all_configurations(self):
|
|
||||||
"""Apply all post-processing configurations to the tokenizer."""
|
|
||||||
# Skip most configurations for Mistral wrapper
|
|
||||||
if isinstance(self.tokenizer, MistralTokenizerWrapper):
|
|
||||||
self._configure_mistral_wrapper()
|
|
||||||
return
|
|
||||||
|
|
||||||
self._configure_padding_token()
|
|
||||||
self._configure_gptneox_settings()
|
|
||||||
self._configure_mistral_padding()
|
|
||||||
self._configure_qwen_tokens()
|
|
||||||
self._add_special_tokens()
|
|
||||||
self._add_regular_tokens()
|
|
||||||
self._configure_chat_template()
|
|
||||||
|
|
||||||
def _configure_mistral_wrapper(self):
|
|
||||||
"""Apply limited configurations for Mistral wrapper."""
|
|
||||||
# Set padding side if needed
|
|
||||||
if (
|
|
||||||
self.cfg.is_mistral_derived_model
|
|
||||||
and self.cfg.flash_attention
|
|
||||||
and not self.cfg.sample_packing
|
|
||||||
):
|
|
||||||
self.tokenizer.padding_side = "left"
|
|
||||||
|
|
||||||
# Configure chat template for Mistral
|
|
||||||
self._configure_chat_template()
|
|
||||||
|
|
||||||
def _configure_padding_token(self):
|
|
||||||
"""Configure padding token for Llama-based tokenizers."""
|
|
||||||
if (
|
|
||||||
self.tokenizer.__class__.__name__ in LLAMA_TOKENIZER_CLASSES
|
|
||||||
and hasattr(self.tokenizer, "pad_token")
|
|
||||||
and not self.tokenizer.pad_token
|
|
||||||
):
|
|
||||||
self.tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
|
||||||
|
|
||||||
def _configure_gptneox_settings(self):
|
|
||||||
"""Configure GPTNeoX-specific settings."""
|
|
||||||
if self.tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
|
||||||
self.tokenizer.add_special_tokens({"pad_token": GPTNEOX_PAD_TOKEN})
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
||||||
|
|
||||||
def _configure_mistral_padding(self):
|
|
||||||
"""Configure left padding for Mistral models with Flash Attention."""
|
|
||||||
if (
|
|
||||||
self.cfg.is_mistral_derived_model
|
|
||||||
and self.cfg.flash_attention
|
|
||||||
and not self.cfg.sample_packing
|
|
||||||
):
|
|
||||||
self.tokenizer.padding_side = "left"
|
|
||||||
|
|
||||||
def _configure_qwen_tokens(self):
|
|
||||||
"""Configure special tokens for Qwen models."""
|
|
||||||
if not self.cfg.is_qwen_derived_model:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Set token IDs
|
|
||||||
token_id_attributes = [
|
|
||||||
"bos_token_id",
|
|
||||||
"eos_token_id",
|
|
||||||
"pad_token_id",
|
|
||||||
"unk_token_id",
|
|
||||||
]
|
|
||||||
for attr_name in token_id_attributes:
|
|
||||||
if getattr(self.tokenizer, attr_name) is None:
|
|
||||||
setattr(self.tokenizer, attr_name, self.tokenizer.eod_id)
|
|
||||||
|
|
||||||
# Set token strings
|
|
||||||
token_name_attributes = ["bos_token", "eos_token", "pad_token", "unk_token"]
|
|
||||||
for attr_name in token_name_attributes:
|
|
||||||
if getattr(self.tokenizer, attr_name) is None:
|
|
||||||
setattr(self.tokenizer, attr_name, QWEN_DEFAULT_TOKEN)
|
|
||||||
|
|
||||||
def _add_special_tokens(self):
|
|
||||||
"""Add special tokens from configuration."""
|
|
||||||
if not self.cfg.special_tokens:
|
|
||||||
return
|
|
||||||
|
|
||||||
special_tokens_dict = self.cfg.special_tokens.to_dict()
|
|
||||||
additional_special_tokens = special_tokens_dict.pop(
|
|
||||||
"additional_special_tokens", None
|
|
||||||
)
|
|
||||||
|
|
||||||
self._validate_and_add_special_tokens(special_tokens_dict)
|
|
||||||
self._update_post_processor_if_needed(special_tokens_dict)
|
|
||||||
self._add_additional_special_tokens_if_present(additional_special_tokens)
|
|
||||||
|
|
||||||
def _validate_and_add_special_tokens(self, special_tokens: Dict[str, str]):
|
|
||||||
"""Validate special tokens for adapter training and add them."""
|
|
||||||
lora_modules_to_save = get_linear_embedding_layers(self.model_config.model_type)
|
|
||||||
|
|
||||||
for key, value in special_tokens.items():
|
|
||||||
self._validate_token_for_adapter(key, value, lora_modules_to_save)
|
|
||||||
self.tokenizer.add_special_tokens(
|
|
||||||
{key: AddedToken(value, rstrip=False, lstrip=False, normalized=False)}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _validate_token_for_adapter(
|
|
||||||
self, key: str, value: str, lora_modules_to_save: List[str]
|
|
||||||
):
|
|
||||||
"""Validate a single token for adapter training requirements."""
|
|
||||||
if not self._should_validate_token_for_adapter(
|
|
||||||
key, value, lora_modules_to_save
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
modules_str = ", ".join(f"`{x}`" for x in lora_modules_to_save)
|
|
||||||
raise ValueError(
|
|
||||||
f"Please set lora_modules_to_save to [{modules_str}] "
|
|
||||||
f"when using an adapter and changing the special tokens."
|
|
||||||
)
|
|
||||||
|
|
||||||
def _should_validate_token_for_adapter(
|
|
||||||
self, key: str, value: str, lora_modules_to_save: List[str]
|
|
||||||
) -> bool:
|
|
||||||
"""Check if token should be validated for adapter configuration."""
|
|
||||||
if key == "pad_token" or not self.cfg.adapter:
|
|
||||||
return False
|
|
||||||
|
|
||||||
current_token = getattr(self.tokenizer, key)
|
|
||||||
token_changed = current_token is None or current_token != value
|
|
||||||
token_is_multi_char = (
|
|
||||||
len(self.tokenizer.encode(value, add_special_tokens=False)) > 2
|
|
||||||
)
|
|
||||||
lora_modules_missing = not self.cfg.lora_modules_to_save or not all(
|
|
||||||
x in self.cfg.lora_modules_to_save for x in lora_modules_to_save
|
|
||||||
)
|
|
||||||
|
|
||||||
return token_changed and token_is_multi_char and lora_modules_missing
|
|
||||||
|
|
||||||
def _update_post_processor_if_needed(self, special_tokens: Dict[str, str]):
|
|
||||||
"""Update post processor for Llama tokenizers when BOS/EOS tokens are added."""
|
|
||||||
has_bos_and_eos = (
|
|
||||||
"bos_token" in special_tokens and "eos_token" in special_tokens
|
|
||||||
)
|
|
||||||
is_fast_llama = (
|
|
||||||
self.tokenizer.__class__.__name__ in FAST_LLAMA_TOKENIZER_CLASSES
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_fast_llama and has_bos_and_eos:
|
|
||||||
self.tokenizer.update_post_processor()
|
|
||||||
|
|
||||||
def _add_additional_special_tokens_if_present(
|
|
||||||
self, additional_special_tokens: Optional[List[str]]
|
|
||||||
):
|
|
||||||
"""Add additional special tokens if they exist."""
|
|
||||||
if additional_special_tokens is not None:
|
|
||||||
self.tokenizer.add_special_tokens(
|
|
||||||
{"additional_special_tokens": additional_special_tokens}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_regular_tokens(self):
|
|
||||||
"""Add regular (non-special) tokens from configuration."""
|
|
||||||
if self.cfg.tokens:
|
|
||||||
self.tokenizer.add_tokens(
|
|
||||||
[
|
|
||||||
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
|
|
||||||
for token in self.cfg.tokens
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _configure_chat_template(self):
|
|
||||||
"""Configure chat template if specified."""
|
|
||||||
if not self.cfg.chat_template:
|
|
||||||
LOG.info(
|
|
||||||
"No Chat template selected. Consider adding a chat template for easier inference."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
chat_template_string = get_chat_template_from_config(
|
|
||||||
cfg=self.cfg,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._should_replace_default_system_message():
|
|
||||||
chat_template_string = chat_template_string.replace(
|
|
||||||
CHATML_DEFAULT_SYSTEM_MESSAGE, self.cfg.default_system_message
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tokenizer.chat_template = chat_template_string
|
|
||||||
|
|
||||||
def _should_replace_default_system_message(self) -> bool:
|
|
||||||
"""Check if default system message should be replaced."""
|
|
||||||
return self.cfg.default_system_message and self.cfg.chat_template == "chatml"
|
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
"""Load and configure the tokenizer based on the provided config.
|
"""Load and configure the tokenizer based on the provided config."""
|
||||||
|
model_config = load_model_config(cfg)
|
||||||
|
tokenizer_kwargs = {}
|
||||||
|
use_fast = True # this is the default
|
||||||
|
|
||||||
This function handles the complete tokenizer loading pipeline:
|
if cfg.tokenizer_use_fast is not None:
|
||||||
- Check if Mistral tokenizer should be used
|
use_fast = cfg.tokenizer_use_fast
|
||||||
- Configure tokenizer parameters and get the appropriate class
|
if cfg.tokenizer_legacy is not None:
|
||||||
- Handle token file modifications if needed
|
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||||
- Initialize the tokenizer with the correct parameters
|
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||||
- Apply all post-processing configurations (padding, special tokens, etc.)
|
|
||||||
- Set up chat templates and logging
|
|
||||||
|
|
||||||
Args:
|
tokenizer_cls = AutoTokenizer
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
if cfg.tokenizer_type:
|
||||||
|
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||||
|
|
||||||
Returns:
|
# Set base tokenizer path
|
||||||
Fully configured tokenizer instance.
|
tokenizer_path = cfg.tokenizer_config
|
||||||
"""
|
|
||||||
# Configure tokenizer parameters
|
|
||||||
config = TokenizerConfiguration(cfg)
|
|
||||||
|
|
||||||
# Check if we should use Mistral tokenizer
|
# Apply token string overrides if specified
|
||||||
try:
|
if cfg.added_tokens_overrides:
|
||||||
tokenizer = config.load_mistral_tokenizer()
|
# Modify tokenizer files and get path to modified tokenizer
|
||||||
except:
|
tokenizer_path = modify_tokenizer_files(
|
||||||
# Standard tokenizer loading
|
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
|
||||||
tokenizer_cls = config.get_tokenizer_class()
|
|
||||||
tokenizer_path = config.get_tokenizer_path()
|
|
||||||
use_fast = config.should_use_fast_tokenizer()
|
|
||||||
tokenizer_kwargs = config.get_tokenizer_kwargs()
|
|
||||||
|
|
||||||
# Initialize the tokenizer
|
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
|
||||||
tokenizer_path,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
|
||||||
use_fast=use_fast,
|
|
||||||
**tokenizer_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply all post-processing configurations
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
post_processor = TokenizerPostProcessor(tokenizer, cfg)
|
tokenizer_path,
|
||||||
post_processor.apply_all_configurations()
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
use_fast=use_fast,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
tokenizer.__class__.__name__
|
||||||
|
in [
|
||||||
|
"LlamaTokenizer",
|
||||||
|
"LlamaTokenizerFast",
|
||||||
|
"CodeLlamaTokenizer",
|
||||||
|
"CodeLlamaTokenizerFast",
|
||||||
|
]
|
||||||
|
and hasattr(tokenizer, "pad_token")
|
||||||
|
and not tokenizer.pad_token
|
||||||
|
):
|
||||||
|
# set a pad_token, but use eos_token so we don't add a new token
|
||||||
|
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||||
|
|
||||||
|
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||||
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
# Mistral's official FA implementation requires left padding
|
||||||
|
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
|
# Qwen base only has single token, so we need to set the special tokens
|
||||||
|
if cfg.is_qwen_derived_model:
|
||||||
|
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
||||||
|
for attr_name in token_ids:
|
||||||
|
if getattr(tokenizer, attr_name) is None:
|
||||||
|
setattr(tokenizer, attr_name, tokenizer.eod_id)
|
||||||
|
|
||||||
|
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
|
||||||
|
for attr_name in token_names:
|
||||||
|
if getattr(tokenizer, attr_name) is None:
|
||||||
|
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||||
|
|
||||||
|
additional_special_tokens = None
|
||||||
|
if cfg.special_tokens:
|
||||||
|
special_tokens = cfg.special_tokens.to_dict()
|
||||||
|
additional_special_tokens = special_tokens.pop(
|
||||||
|
"additional_special_tokens", None
|
||||||
|
)
|
||||||
|
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
|
||||||
|
for k, val in special_tokens.items():
|
||||||
|
# check if new special token is not already in tokenizer and
|
||||||
|
# is adapter training to make sure lora_modules_to_save is set
|
||||||
|
# pylint: disable=too-many-boolean-expressions
|
||||||
|
if (
|
||||||
|
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
||||||
|
and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
|
||||||
|
and cfg.adapter
|
||||||
|
and (
|
||||||
|
not cfg.lora_modules_to_save
|
||||||
|
or not all(
|
||||||
|
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||||
|
)
|
||||||
|
)
|
||||||
|
and k != "pad_token"
|
||||||
|
):
|
||||||
|
lora_modules_to_save = ", ".join(
|
||||||
|
[f"`{x}`" for x in lora_modules_to_save]
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer.add_special_tokens(
|
||||||
|
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we add bos_token and eos_token, we need to update the post processor to
|
||||||
|
# handle them correctly.
|
||||||
|
# https://github.com/huggingface/transformers/pull/24132
|
||||||
|
bos_or_eos_in_special_tokens = (
|
||||||
|
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
tokenizer.__class__.__name__
|
||||||
|
in (
|
||||||
|
"LlamaTokenizerFast",
|
||||||
|
"CodeLlamaTokenizerFast",
|
||||||
|
)
|
||||||
|
and bos_or_eos_in_special_tokens
|
||||||
|
):
|
||||||
|
tokenizer.update_post_processor()
|
||||||
|
|
||||||
|
if cfg.tokens:
|
||||||
|
tokenizer.add_tokens(
|
||||||
|
[
|
||||||
|
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
|
||||||
|
for token in cfg.tokens
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Additional special tokens are a List, and need to be treated differently than regular special
|
||||||
|
# tokens. We add them after we have called `add_tokens` in case these additional special tokens
|
||||||
|
# are new tokens.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
#
|
||||||
|
# ```py
|
||||||
|
# special_tokens:
|
||||||
|
# additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
|
||||||
|
# ```
|
||||||
|
if additional_special_tokens is not None:
|
||||||
|
tokenizer.add_special_tokens(
|
||||||
|
{"additional_special_tokens": additional_special_tokens}
|
||||||
|
)
|
||||||
|
|
||||||
if is_main_process(use_environ=True):
|
if is_main_process(use_environ=True):
|
||||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||||
@@ -654,4 +263,19 @@ def load_tokenizer(cfg):
|
|||||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||||
|
|
||||||
|
if cfg.chat_template:
|
||||||
|
chat_template_string = get_chat_template_from_config(
|
||||||
|
cfg=cfg,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
if cfg.default_system_message and cfg.chat_template == "chatml":
|
||||||
|
chat_template_string = chat_template_string.replace(
|
||||||
|
"You are a helpful assistant.", cfg.default_system_message
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer.chat_template = chat_template_string
|
||||||
|
else:
|
||||||
|
LOG.info(
|
||||||
|
"No Chat template selected. Consider adding a chat template for easier inference."
|
||||||
|
)
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|||||||
@@ -67,10 +67,6 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
LOG.warning("Empty text requested for tokenization.")
|
LOG.warning("Empty text requested for tokenization.")
|
||||||
return empty
|
return empty
|
||||||
|
|
||||||
import ipdb
|
|
||||||
|
|
||||||
ipdb.set_trace()
|
|
||||||
|
|
||||||
result = self.tokenizer(
|
result = self.tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import typing
|
||||||
import weakref
|
import weakref
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -25,7 +28,6 @@ from axolotl.common.datasets import TrainDatasetMeta
|
|||||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders import (
|
from axolotl.loaders import (
|
||||||
ModelLoader,
|
ModelLoader,
|
||||||
@@ -45,6 +47,9 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
BetterTransformer = None
|
BetterTransformer = None
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -472,7 +477,7 @@ def handle_untrained_tokens_fix(
|
|||||||
|
|
||||||
|
|
||||||
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||||
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
|
||||||
PeftModel | PreTrainedModel,
|
PeftModel | PreTrainedModel,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PeftConfig | None,
|
PeftConfig | None,
|
||||||
|
|||||||
@@ -52,3 +52,10 @@ def patch_optimized_env():
|
|||||||
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
|
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
set_pytorch_cuda_alloc_conf()
|
set_pytorch_cuda_alloc_conf()
|
||||||
|
|
||||||
|
|
||||||
|
def get_not_null(value, default=None):
|
||||||
|
"""
|
||||||
|
return the value if it's not None, otherwise return the default value
|
||||||
|
"""
|
||||||
|
return value if value is not None else default
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,7 +1,7 @@
|
|||||||
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
|
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
@@ -161,7 +161,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
if not isinstance(features[0], list):
|
if not isinstance(features[0], list):
|
||||||
features = [features]
|
features: List[List[dict]] = [features]
|
||||||
out_features = [{} for _ in features]
|
out_features = [{} for _ in features]
|
||||||
for i, features_ in enumerate(features):
|
for i, features_ in enumerate(features):
|
||||||
for feature in features_[0].keys():
|
for feature in features_[0].keys():
|
||||||
|
|||||||
@@ -486,10 +486,6 @@ def get_dataset_wrapper(
|
|||||||
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
||||||
)
|
)
|
||||||
|
|
||||||
import ipdb
|
|
||||||
|
|
||||||
ipdb.set_trace()
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(dataset, Dataset)
|
isinstance(dataset, Dataset)
|
||||||
and "input_ids" in dataset.features
|
and "input_ids" in dataset.features
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ def retry_on_request_exceptions(
|
|||||||
except (
|
except (
|
||||||
requests.exceptions.ReadTimeout,
|
requests.exceptions.ReadTimeout,
|
||||||
requests.exceptions.ConnectionError,
|
requests.exceptions.ConnectionError,
|
||||||
|
requests.exceptions.HTTPError,
|
||||||
huggingface_hub.errors.HfHubHTTPError,
|
huggingface_hub.errors.HfHubHTTPError,
|
||||||
) as exc:
|
) as exc:
|
||||||
if attempt < max_retries - 1:
|
if attempt < max_retries - 1:
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
batch_max_len: int, # Maximum sequence length (bin capacity)
|
batch_max_len: int, # Maximum sequence length (bin capacity)
|
||||||
lengths: np.ndarray, # Sequence lengths
|
lengths: np.ndarray, # Sequence lengths
|
||||||
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
||||||
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
|
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
||||||
num_count_samples: int = 16, # Number of times to estimate batch count
|
num_count_samples: int = 16, # Number of times to estimate batch count
|
||||||
sequential: bool = False, # Whether to use sequential packing
|
sequential: bool = False, # Whether to use sequential packing
|
||||||
group_size: int = 100_000, # Size of groups for parallel packing
|
group_size: int = 100_000, # Size of groups for parallel packing
|
||||||
@@ -443,10 +443,18 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
|
|
||||||
if self._len_across_ranks is None:
|
if self._len_across_ranks is None:
|
||||||
# Sample multiple times to get stable estimate
|
# Sample multiple times to get stable estimate
|
||||||
len_batches = min( # pylint: disable=consider-using-generator
|
_sampled_lens = []
|
||||||
[len(self._batches) for _ in range(self.num_count_samples)]
|
for _ in range(self.num_count_samples):
|
||||||
)
|
self._batches = None # Reset cached batches
|
||||||
|
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
|
||||||
|
len_batches = min(_sampled_lens)
|
||||||
|
|
||||||
# Gather minimum across all ranks
|
# Gather minimum across all ranks
|
||||||
self._len_across_ranks = self.gather_len_batches(len_batches)
|
if self._len_across_ranks is None:
|
||||||
|
self._len_across_ranks = self.gather_len_batches(len_batches)
|
||||||
|
else:
|
||||||
|
self._len_across_ranks = min(
|
||||||
|
self._len_across_ranks, self.gather_len_batches(len_batches)
|
||||||
|
)
|
||||||
|
|
||||||
return self._len_across_ranks
|
return self._len_across_ranks
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from datasets import IterableDataset, disable_caching, enable_caching
|
|||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
|
||||||
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
@@ -482,6 +481,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if cfg.dataloader_drop_last:
|
||||||
|
# drop the last batch for each epoch
|
||||||
|
total_num_steps -= int(math.ceil(cfg.num_epochs))
|
||||||
|
|
||||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||||
@@ -629,6 +631,8 @@ def setup_trainer(
|
|||||||
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
|
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
|
||||||
on the provided parameters.
|
on the provided parameters.
|
||||||
"""
|
"""
|
||||||
|
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.torch_compile
|
cfg.torch_compile
|
||||||
and cfg.fsdp_config
|
and cfg.fsdp_config
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
"""Test cases for tokenizer loading."""
|
"""
|
||||||
|
Test cases for the tokenizer loading
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -9,7 +13,9 @@ from tests.hf_offline_utils import enable_hf_offline
|
|||||||
|
|
||||||
|
|
||||||
class TestTokenizers:
|
class TestTokenizers:
|
||||||
"""Test class for the load_tokenizer fn"""
|
"""
|
||||||
|
test class for the load_tokenizer fn
|
||||||
|
"""
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_default_use_fast(self):
|
def test_default_use_fast(self):
|
||||||
@@ -149,50 +155,6 @@ class TestTokenizers:
|
|||||||
):
|
):
|
||||||
load_tokenizer(cfg)
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
def test_mistral_tokenizer_auto_detection(self):
|
|
||||||
"""Test that Mistral models are auto-detected and use MistralTokenizerWrapper"""
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
|
||||||
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
|
||||||
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
|
|
||||||
|
|
||||||
def test_mixtral_tokenizer_auto_detection(self):
|
if __name__ == "__main__":
|
||||||
"""Test that Mixtral models are auto-detected and use MistralTokenizerWrapper"""
|
unittest.main()
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "model-hub/Mixtral-8x7B-v0.1",
|
|
||||||
"tokenizer_config": "model-hub/Mixtral-8x7B-v0.1",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
|
||||||
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
|
|
||||||
|
|
||||||
def test_mistral_tokenizer_basic_functionality(self):
|
|
||||||
"""Test basic encode/decode functionality of MistralTokenizerWrapper"""
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
|
||||||
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
|
||||||
|
|
||||||
# Test basic encoding
|
|
||||||
text = "Hello, world!"
|
|
||||||
tokens = tokenizer.encode(text)
|
|
||||||
assert isinstance(tokens, list)
|
|
||||||
assert len(tokens) > 0
|
|
||||||
|
|
||||||
# Test basic decoding
|
|
||||||
decoded = tokenizer.decode(tokens)
|
|
||||||
assert isinstance(decoded, str)
|
|
||||||
|
|
||||||
# Test token properties are accessible
|
|
||||||
assert hasattr(tokenizer, "eos_token_id")
|
|
||||||
assert hasattr(tokenizer, "bos_token_id")
|
|
||||||
assert isinstance(tokenizer.eos_token_id, int)
|
|
||||||
assert isinstance(tokenizer.bos_token_id, int)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user