Compare commits

..

33 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Dan Saunders | 345a159796 | coderabbit comments | 2025-06-07 04:50:29 +00:00 |
| Dan Saunders | 657bffd85f | update posthog dep | 2025-06-05 23:46:20 +00:00 |
| Dan Saunders | f0dde8e2d5 | lint | 2025-06-05 23:41:46 +00:00 |
| Dan Saunders | 25fa4df70f | fix | 2025-06-05 23:33:46 +00:00 |
| Dan Saunders | e735f4270b | slight changes | 2025-06-05 23:33:46 +00:00 |
| Dan Saunders | 035e7a2f4c | simplifying | 2025-06-05 23:33:46 +00:00 |
| Dan Saunders | 2d36c11264 | minor fixes | 2025-06-05 23:33:46 +00:00 |
| Dan Saunders | b8ec5bdccf | doc update | 2025-06-05 23:33:44 +00:00 |
| Dan Saunders | 249405b46e | docs fix | 2025-06-05 23:31:44 +00:00 |
| Dan Saunders | d3be84fec2 | enable / disable logic update | 2025-06-05 23:31:44 +00:00 |
| Dan Saunders | 1c74ab175f | opt-in version of telemetry | 2025-06-05 23:31:44 +00:00 |
| Dan Saunders | b2f1fc109a | distributed fix | 2025-06-05 23:31:44 +00:00 |
| Dan Saunders | 5a2a80cc48 | fix issue with tests in ci | 2025-06-05 23:31:44 +00:00 |
| Dan Saunders | 4033fe74f8 | fixes | 2025-06-05 23:31:44 +00:00 |
| Dan Saunders | e9df4444be | remove duplicate info | 2025-06-05 23:31:44 +00:00 |
| Dan Saunders | ffd2985750 | adding runtime metrics / system info additional accelerator support, etc. | 2025-06-05 23:31:44 +00:00 |
| Dan Saunders | 17310f9acc | adding runtime metrics / system info additional accelerator support, etc. | 2025-06-05 23:31:44 +00:00 |
| Dan Saunders | 71ae6f9f87 | improved redaction, send system info during model config load telemetry, etc. | 2025-06-05 23:31:08 +00:00 |
| Dan Saunders | 9dd1092f8f | doc update | 2025-06-05 23:27:29 +00:00 |
| Dan Saunders | 2c2f2647a9 | fix | 2025-06-05 23:27:29 +00:00 |
| Dan Saunders | 98313a6b3f | adding back in base_model redaction w/ whitelist | 2025-06-05 23:27:29 +00:00 |
| Dan Saunders | 8b75205d3b | sleep on all ranks in distributed setting | 2025-06-05 23:27:29 +00:00 |
| Dan Saunders | ef4990f304 | simplifying path redaction | 2025-06-05 23:27:29 +00:00 |
| Dan Saunders | db3297b090 | small update / fix | 2025-06-05 23:27:27 +00:00 |
| Dan Saunders | 86ed554bda | tests for runtime metrics telemetry and assoc. callback | 2025-06-05 23:26:07 +00:00 |
| Dan Saunders | f254d7d5a2 | adding runtime metrics (cpu + gpu memory, steps/s, etc.) | 2025-06-05 23:26:05 +00:00 |
| Dan Saunders | d8b0522ea0 | updated sanitization logic, tests | 2025-06-05 23:20:51 +00:00 |
| Dan Saunders | 1edd6b9524 | update error file path sanitization function; adding more error tracking | 2025-06-05 23:20:49 +00:00 |
| Dan Saunders | 66c6fb56cb | progress on telemetry: config load, process, model load, train start / end, error tracking | 2025-06-05 22:59:50 +00:00 |
| Dan Saunders | 90b39ce112 | updates | 2025-06-05 22:49:15 +00:00 |
| Dan Saunders | 5afab46cc6 | updates | 2025-06-05 22:49:15 +00:00 |
| Dan Saunders | bd152c6115 | adding todo | 2025-06-05 22:49:15 +00:00 |
| Dan Saunders | 76336743ff | initial telemetry manager impl | 2025-06-05 22:49:14 +00:00 |
57 changed files with 2959 additions and 2138 deletions

View File

@@ -112,6 +112,13 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
## 📈 Telemetry
Axolotl has opt-in telemetry that helps us understand how the project is being used
and prioritize improvements. We collect basic system information, model types, and
error rates; we never collect personal data or file paths. Telemetry is disabled by
default. To enable it, set `AXOLOTL_DO_NOT_TRACK=0`. For more details, see our [telemetry documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html).
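For example, to opt in from a Python entry point (a minimal sketch; it assumes the
environment variable is read when Axolotl's telemetry manager initializes at import time):

```python
import os

# Opt in to telemetry. Set this before importing axolotl, since the
# TelemetryManager singleton reads the environment when it is created.
os.environ["AXOLOTL_DO_NOT_TRACK"] = "0"  # "1" opts out explicitly
```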
## Supported Models
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |

View File

@@ -236,6 +236,7 @@ website:
- docs/inference.qmd
- docs/cli.qmd
- docs/config.qmd
- docs/telemetry.qmd
- text: "API Reference"
href: docs/api

View File

@@ -1,31 +0,0 @@
{
"compile": {
"disable": false,
"backend": "inductor"
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

docs/telemetry.qmd (new file, +59 lines)
View File

@@ -0,0 +1,59 @@
---
title: Telemetry
description: A description of the opt-in telemetry implementation in Axolotl.
---
# Telemetry in Axolotl
Axolotl implements anonymous telemetry to help maintainers understand how the library
is used and where users encounter issues. This data helps prioritize features, optimize
performance, and fix bugs.
## Data Collection
We collect:
- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
version, etc.
- Hardware info: CPU count, memory, GPU count and models
- Runtime metrics: Training progress, memory usage, timing information
- Usage patterns: Models (from a whitelist) and configurations used
- Error tracking: Stack traces and error messages (sanitized to remove personal
information)
Personally identifiable information (PII) is not collected.
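For illustration, a single sanitized event might look roughly like the following (a
hypothetical sketch; the field names are illustrative, not the exact schema):

```python
# Hypothetical shape of a sanitized telemetry event payload; the actual
# fields are defined by the axolotl.telemetry module and may differ.
event = {
    "event_type": "train-start",
    "run_id": "0b2d6c1e-anon",  # anonymous per-run identifier
    "system": {"os": "Linux", "python": "3.11", "torch": "2.3.0"},
    "hardware": {"cpu_count": 32, "gpu_count": 8, "gpu": "NVIDIA A100"},
    "config": {"base_model": "[REDACTED]", "sample_packing": True},
}
```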
## Implementation
Telemetry is implemented using PostHog and consists of:
- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
telemetry system and provides methods for tracking events.
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
sends sanitized stack traces.
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
runtime metrics during training.
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
runtime metrics telemetry.
The telemetry system will block training startup for 15 seconds to ensure users are
aware of data collection, unless telemetry is explicitly enabled or disabled.
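A minimal sketch of how a `send_errors`-style decorator can work (illustrative only; the
real decorator in `axolotl.telemetry.errors` also sanitizes traces and handles
distributed ranks):

```python
import functools
import traceback


def _report(trace: str) -> None:
    # Stand-in for the real reporting path, which sends a PostHog event
    # through the TelemetryManager singleton when telemetry is enabled.
    print(f"[telemetry] captured: {trace.splitlines()[-1]}")


def send_errors(func):
    """Report a sanitized stack trace on failure, then re-raise."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            _report(traceback.format_exc())
            raise

    return wrapper
```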
## Opt-In Mechanism
Telemetry is **disabled by default** and is strictly opt-in. To enable it, set `AXOLOTL_DO_NOT_TRACK=0`.
To suppress the telemetry notice displayed when `train` and other commands start, set the
variable explicitly: `AXOLOTL_DO_NOT_TRACK=0` (enable telemetry) or `AXOLOTL_DO_NOT_TRACK=1`
(disable telemetry).
**Note**: Telemetry will move to an opt-out model in a later release.
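In code, the enabled state is exposed on the singleton; this mirrors how the trainer
builder decides whether to register `TelemetryCallback`:

```python
from axolotl.telemetry.manager import TelemetryManager

telemetry_manager = TelemetryManager.get_instance()
if telemetry_manager.enabled:
    print("Telemetry is enabled for this run.")
```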
## Privacy
- All path-like config information is automatically redacted from telemetry data (see the sketch below this list)
- Model information is only collected for whitelisted organizations
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
- Each run generates a unique anonymous ID
- This allows us to link the telemetry events from a single training run
- Telemetry is only sent from the main process to avoid duplicate events
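For instance, a simple string-based redaction pass over path-like values might look
like this (a sketch; the actual redaction logic lives in the telemetry module and may
differ):

```python
import re

# Hypothetical path redaction: replace anything that looks like a POSIX or
# Windows filesystem path before the event is sent.
_PATH_RE = re.compile(r"(?:/|[A-Za-z]:\\)[^\s'\"]+")


def redact_paths(value: str) -> str:
    return _PATH_RE.sub("[REDACTED_PATH]", value)


print(redact_paths("loading /home/user/models/llama/config.json"))
# -> loading [REDACTED_PATH]
```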

View File

@@ -67,3 +67,6 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
# telemetry
posthog>=4.2.0

View File

@@ -14,6 +14,8 @@ import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -28,6 +30,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__, use_environ=True)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
@@ -159,6 +163,7 @@ def plugin_set_cfg(cfg: DictDefault):
plugin_manager.cfg = cfg
@send_errors
def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
@@ -192,6 +197,8 @@ def load_cfg(
temp_file.close()
cfg.axolotl_config_path = temp_file.name
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
@@ -233,4 +240,6 @@ def load_cfg(
setup_comet_env_vars(cfg)
plugin_set_cfg(cfg)
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
return cfg

View File

@@ -16,6 +16,7 @@ from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
@@ -42,6 +43,7 @@ def get_multi_line_input() -> str:
return instruction
@send_errors
def do_inference(
*,
cfg: DictDefault,
@@ -135,6 +137,7 @@ def do_inference(
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@send_errors
def do_inference_gradio(
*,
cfg: DictDefault,

View File

@@ -9,12 +9,14 @@ from dotenv import load_dotenv
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@send_errors
def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config

View File

@@ -24,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.telemetry.errors import send_errors
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@@ -118,6 +119,7 @@ def _distributed_checkpoint_to_merged_weights(
return save_path_
@send_errors
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,

View File

@@ -18,6 +18,7 @@ from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import disable_datasets_caching
@@ -25,6 +26,7 @@ from axolotl.utils.trainer import disable_datasets_caching
LOG = get_logger(__name__)
@send_errors
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
"""
Preprocesses dataset specified in axolotl config.

View File

@@ -10,6 +10,7 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
@@ -45,6 +46,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
)
@send_errors
def load_datasets(
*,
cfg: DictDefault,
@@ -112,6 +114,7 @@ def load_datasets(
)
@send_errors
def load_preference_datasets(
*,
cfg: DictDefault,

View File

@@ -31,6 +31,8 @@ from transformers.training_args import OptimizerNames
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.telemetry.callbacks import TelemetryCallback
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
GCCallback,
@@ -145,6 +147,10 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(GPUStatsCallback(cfg=self.cfg))
telemetry_manager = TelemetryManager.get_instance()
if telemetry_manager.enabled:
callbacks.append(TelemetryCallback())
return callbacks
def get_post_trainer_create_callbacks(self, trainer):

View File

@@ -21,6 +21,11 @@ from axolotl.core.trainers import (
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
@@ -125,9 +130,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def _get_trainer_cls(self):
"""
Gets the trainer class for the given configuration.
"""
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
@@ -144,12 +146,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlTrainer
def build(self, total_num_steps):
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps
)
@@ -318,12 +314,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
plugin_training_args = plugin_manager.get_training_args(self.cfg)
if plugin_training_args:
training_arguments_kwargs.update(plugin_training_args)
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs["kd_zscore_base_temp"] = (
self.cfg.kd_zscore_base_temp
)
if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs["kd_top_k_before_softmax"] = (
self.cfg.kd_top_k_before_softmax
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -404,10 +408,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return trainer
def build_collator(
self,
training_args, # type: "AxolotlTrainingArguments" # type: ignore
is_eval=False,
**kwargs,
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if (
@@ -436,18 +437,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
]
collator_args = [self.tokenizer]
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
self.cfg, is_eval=is_eval
)
if collator_cls_and_kwargs:
collator = collator_cls_and_kwargs[0]
if kwargs and isinstance(kwargs, dict):
kwargs.update(collator_cls_and_kwargs[1])
elif self.cfg.reward_model:
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
@@ -478,6 +468,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import (
DataCollatorForKD,
KDBatchSamplerDataCollatorForSeq2Seq,
)
if self.cfg.sample_packing:
collator = KDBatchSamplerDataCollatorForSeq2Seq
else:
collator = DataCollatorForKD
else:
collator = DataCollatorForSeq2Seq

View File

@@ -12,6 +12,11 @@ from axolotl.core.trainers import (
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.logging import get_logger
@@ -74,12 +79,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Returns training_args and trainer_kwargs
"""
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
)
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps=total_num_steps
)
@@ -166,13 +165,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
plugin_training_args = plugin_manager.get_training_args(self.cfg)
if plugin_training_args:
training_args_kwargs.update(plugin_training_args)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
logging_first_step=True,
**training_args_kwargs,

View File

@@ -33,7 +33,6 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils import get_not_null
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -102,7 +101,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
)
batch_max_len = train_batch_size * self.args.max_seq_length
sampler = MultipackBatchSampler(
return MultipackBatchSampler(
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
@@ -114,9 +113,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
drop_last=True,
)
len(sampler)
return sampler
def _get_train_sampler(
self, train_dataset: Optional[Dataset] = None
) -> Optional[Sampler]:
@@ -224,9 +220,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
}
if not isinstance(dataset, torch.utils.data.IterableDataset):
dataloader_params["drop_last"] = get_not_null(
self.args.dataloader_drop_last, True
)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if sampler_fn is not None:
sampler = sampler_fn(dataset)
if isinstance(sampler, BatchSampler):

View File

@@ -2,17 +2,238 @@
extra axolotl specific training args
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Type
from typing import Optional
from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.integrations.config import merge_training_args
AxolotlTrainingMixins: Type = merge_training_args()
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
kd_ce_alpha: Optional[float] = field(
default=None,
metadata={
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
},
)
kd_alpha: Optional[float] = field(
default=1.0,
metadata={"help": "The alpha scaling parameter for KD loss"},
)
kd_temperature: Optional[float] = field(
default=1.0,
metadata={
"help": "the temperature parameter for KL divergence loss when using KD"
},
)
kd_zscore_base_temp: Optional[float] = field(
default=None,
metadata={
"help": "the base temperature parameter for KL divergence with z-score when using KD"
},
)
kd_top_k_before_softmax: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
},
)
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section
@dataclass

View File

@@ -1,220 +0,0 @@
"""
Base Axolotl Training Mixins shared across various trainer configs
"""
from dataclasses import dataclass, field
from typing import Optional
from PIL.Image import Resampling
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
# kd_ce_alpha: Optional[float] = field(
# default=None,
# metadata={
# "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
# },
# )
#
# kd_alpha: Optional[float] = field(
# default=1.0,
# metadata={"help": "The alpha scaling parameter for KD loss"},
# )
#
# kd_temperature: Optional[float] = field(
# default=1.0,
# metadata={
# "help": "the temperature parameter for KL divergence loss when using KD"
# },
# )
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section

View File

@@ -11,6 +11,7 @@ from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.telemetry.errors import send_errors
from axolotl.train import (
TrainDatasetMeta,
setup_model_and_tokenizer,
@@ -63,6 +64,7 @@ def evaluate_dataset(
return metrics
@send_errors
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets.

View File

@@ -22,7 +22,6 @@ from __future__ import annotations
import collections
import importlib
import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
@@ -84,11 +83,6 @@ class BasePlugin:
def get_input_args(self) -> str | None:
"""Returns a pydantic model for the plugin's input arguments."""
def get_training_args_mixin(self) -> str | None:
"""
Returns a dataclass model for the plugin's training arguments.
"""
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -164,31 +158,6 @@ class BasePlugin:
trainer: The trainer object for training.
"""
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
"""
Returns custom training arguments to set on TrainingArgs.
Args:
cfg: The global axolotl configuration.
Returns:
object: dict containing the training arguments.
"""
def get_collator_cls_and_kwargs(
self, cfg: DictDefault, is_eval: bool = False
): # pylint: disable=unused-argument):
"""
Returns a custom class for the collator.
Args:
cfg: The global axolotl configuration.
is_eval: Whether this is an eval split.
Returns:
class: The class for the collator.
"""
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
@@ -309,7 +278,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
return plugin
class PluginManager: # pylint: disable=too-many-public-methods
class PluginManager:
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.
@@ -368,11 +337,8 @@ class PluginManager: # pylint: disable=too-many-public-methods
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
LOG.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError as exc:
except ImportError:
LOG.error(f"Failed to load plugin: {plugin_name}")
# print stacktrace
traceback.print_exc()
print(f"Error: {exc}")
def get_input_args(self) -> list[str]:
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
@@ -387,20 +353,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
input_args.append(input_args_from_plugin)
return input_args
def get_training_args_mixin(self):
"""
Returns a list of dataclasses for all registered plugins' training args mixins'
Returns:
list[str]: A list of dataclsses
"""
training_args = []
for plugin in self.plugins.values():
training_args_from_plugin = plugin.get_training_args_mixin()
if training_args_from_plugin is not None:
training_args.append(training_args_from_plugin)
return training_args
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -490,42 +442,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
return trainer_cls
return None
def get_training_args(self, cfg):
"""
Calls the get_training_args method of all registered plugins and returns the combined training arguments.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
object: The training arguments
"""
training_args_kwargs = {}
for plugin in self.plugins.values():
training_args = plugin.get_training_args(cfg)
if training_args is not None:
training_args_kwargs.update(training_args)
return training_args_kwargs
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
"""
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
Parameters:
cfg (dict): The configuration for the plugins.
is_eval (bool): Whether this is an eval split.
Returns:
object: The collator class, or None if none was found.
"""
for plugin in self.plugins.values():
collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
if collator is not None:
collator_cls, collator_kwargs = collator
return collator_cls, collator_kwargs
return None
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Calls the `post_trainer_create` method of all registered plugins.

View File

@@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio
This was moved here to prevent circular imports.
"""
from typing import Any, Dict, List, Type
from typing import Any, Dict, List
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
@@ -61,43 +61,3 @@ def merge_input_args():
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
def merge_training_args() -> Type:
"""
Merges training arguments from registered plugins with the base TrainingArguments.
This function retrieves the training arguments from registered plugins using the PluginManager.
It then dynamically creates new classes, AxolotlTrainingMixins,
that inherit from the base configurations and include the training arguments from the plugins.
Returns:
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
"""
# pylint: disable=duplicate-code
from axolotl.core.training_args_base import (
AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
)
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
mixin_classes = []
dynamic_input = ""
for plugin_args in training_args_mixins:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
mixin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n"
namespace: Dict[Any, Any] = {}
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, {**globals(), **local_vars}, namespace
)
AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
"AxolotlTrainingMixins"
]
return AxolotlTrainingMixins
return AxolotlTrainingMixinsBase

View File

@@ -21,32 +21,3 @@ datasets:
```
An example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)
## Online KD (sglang)
```bash
export UV_TORCH_BACKEND=cu124
uv venv sglang --python 3.11
source sglang/bin/activate
uv pip install --upgrade pip
uv pip install setuptools
uv pip install torch~=2.5.1 --index-url https://download.pytorch.org/whl/cu124
uv pip install sgl-kernel --force-reinstall --no-deps
uv pip install "sglang[all]>=0.4.2.post4" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
```
## Online KD (vllm)
```bash
VLLM_USE_V1=0 vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --max-num-seqs 256 --gpu_memory_utilization 0.2 --enable-chunked-prefill
```
```bash
vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --no-enable-prefix-caching --gpu-memory-utilization 0.3 --max-num-batched-tokens 131072 --host 0.0.0.0
```
```bash
python -m sglang.launch_server --model-path open-r1/OlympicCoder-32B --tensor-parallel-size 8 --port 8080 --host 0.0.0.0 --max-running-requests 256 --context-length 16400 --mem-fraction-static 0.2 --schedule-conservativeness 0.3 --chunked-prefill-size 131072 --schedule-policy fcfs --skip-tokenizer-init
```

View File

@@ -15,12 +15,7 @@
"""
Plugin init to add KD support to Axolotl.
"""
from typing import Any
from transformers import Trainer
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
@@ -33,75 +28,9 @@ class KDPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.kd.KDArgs"
def get_training_args_mixin(self):
return "axolotl.integrations.kd.args.KDTrainingArgsMixin"
def get_trainer_cls(self, cfg):
if cfg.kd_trainer:
from .trainer import AxolotlKDTrainer
return AxolotlKDTrainer
return None
def get_training_args(self, cfg):
return {
"kd_ce_alpha": cfg.kd_ce_alpha,
"kd_alpha": cfg.kd_alpha,
"kd_temperature": cfg.kd_temperature,
"kd_beta": cfg.kd_beta,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
if not cfg.kd_trainer:
return None, None
from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
use_batch_sampler_collator = False
if is_eval is False and cfg.sample_packing:
use_batch_sampler_collator = True
if cfg.eval_sample_packing and is_eval:
use_batch_sampler_collator = True
if cfg.kd_online_server_base_url:
from .collator_online_teacher import OnlineTeacherCollator
return OnlineTeacherCollator, {
"kd_online_server_base_url": cfg.kd_online_server_base_url,
"kd_online_topk": cfg.kd_online_topk,
"kd_temperature": cfg.kd_temperature,
"kd_online_server": cfg.kd_online_server,
"kd_online_timeout": cfg.kd_online_timeout,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
if use_batch_sampler_collator:
return KDBatchSamplerDataCollatorForSeq2Seq, {}
return DataCollatorForKD, {}
def pre_model_load(self, cfg):
from .kernels.models import apply_kernel
apply_kernel(cfg.model_config_type)
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds temp scheduler callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url:
callback = KDTemperatureSchedulerCallback(
cfg.kd_temperature,
cfg.kd_temperature_min,
trainer,
)
return [callback]
return []

View File

@@ -15,19 +15,9 @@
"""
Plugin args for KD support.
"""
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class InferenceServerType(str, Enum):
"""
Online inferences server types to handle different request args
"""
vllm = "vllm" # pylint: disable=invalid-name
sglang = "sglang" # pylint: disable=invalid-name
from pydantic import BaseModel
class KDArgs(BaseModel):
@@ -35,41 +25,13 @@ class KDArgs(BaseModel):
Input args for knowledge distillation.
"""
kd_trainer: float | None = None # whether to use KD trainer
kd_ce_alpha: float | None = (
kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[float] = (
None # loss coefficient for cross-entropy loss during KD
)
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: bool | None = (
None # whether to normalize student logits during KD
)
# TODO online kd
kd_online_server_base_url: str | None = None
kd_online_topk: int | None = None
kd_online_server: InferenceServerType | None = Field(
default_factory=lambda: InferenceServerType.vllm
)
kd_online_timeout: int | None = 120
kd_temperature_min: float | None = (
None # kd temperature scheduling during online kd
)
@dataclass
class KDTrainingArgsMixin:
"""
Additional args for KD training.
"""
kd_ce_alpha: float | None = (
None # loss coefficient for cross-entropy loss during KD
)
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: float | None = (
None # whether to normalize student logits during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[bool] = (
None # whether to sample top k before softmax during KD
)

View File

@@ -1,36 +0,0 @@
"""
Transformers trainer callbacks to schedule the KD temperature during training
"""
import math
from transformers.trainer_callback import TrainerCallback
class KDTemperatureSchedulerCallback(TrainerCallback):
"""
KD temperature scheduler callback for the trainer.
"""
def __init__(self, temperature_start, temperature_min, trainer):
self.temperature_start = temperature_start
self.temperature_min = temperature_min
self.temperature = temperature_start
self.trainer = trainer
def on_step_end(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
# cosine decay temperature over the max steps
progress = state.global_step / state.max_steps
# Cosine decay factor: 0.5 * (1 + cos(pi * progress))
# This factor goes from 1 (at progress=0) to 0 (at progress=1)
decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
self.temperature = self.temperature_start - (
(self.temperature_start - self.temperature_min) * (1.0 - decay_factor)
)
if hasattr(self.trainer.data_collator, "kd_temperature"):
self.trainer.data_collator.kd_temperature = self.temperature

View File

@@ -15,15 +15,12 @@
"""
Chat template prompt strategy loader with KD support
"""
import logging
from typing import Any, Dict
import torch
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
LOG = logging.getLogger(__name__)
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
"""
@@ -104,8 +101,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
# we shift for causal models in the trainer, so start the range from 0
for _ in range(0, input_padding_len):
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
@@ -144,10 +143,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# normalize probabilities to sum to 1 in case they aren't already
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
if teacher_probs_t1_sum > 1e-9:
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
@@ -167,115 +162,12 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
sample["target_mask"] = target_mask
return sample
def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
"""
Strat for datasets with complete structured KD logprob data
"""
def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
# pylint: disable=duplicate-code
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
# get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [
len(logprobs[i])
for i in range(len(logprobs))
if logprobs[i] is not None and len(logprobs[i])
]
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
top_k = min(max_top_k, min_top_k)
if top_k == 0:
raise ValueError("No non-zero top-k logprobs found.")
target_logprobs = []
target_token_ids = []
target_mask = []
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
logprobs = [row[:top_k] for row in logprobs]
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
# we shift for causal models in the trainer, so start the range from 0
for _ in range(0, input_padding_len):
if shift == 1:
# since we started at index 1 for causal, we need one more padding token
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for token_pos_logprobs, pos_target_token_ids in zip(
logprobs, sample["target_token_ids"]
):
# Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor(
token_pos_logprobs, dtype=torch.float
)
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# normalize probabilities to sum to 1 in case they aren't already
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
if teacher_probs_t1_sum > 1e-9:
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
else:
teacher_probs_t2 = teacher_probs_t1
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist()
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(pos_target_token_ids)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
@@ -285,10 +177,8 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
target_token_ids = prompt.pop("target_token_ids")
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt["target_token_ids"] = target_token_ids
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
@@ -300,7 +190,7 @@ class KDStrategyLoader(StrategyLoader):
"""
def _get_strategy_cls(self):
return ChatTemplateStrategyWithKDv2
return ChatTemplateStrategyWithKD
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
strategy_params = super()._get_strategy_params(cfg, ds_cfg)

View File

@@ -47,16 +47,11 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
position_pad_token_id: int = 0
return_tensors: str = "pt"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
padding_side = self.tokenizer.padding_side
max_len = 0
# Pad labels and position_ids first
for feature_name, pad_token_id in [
@@ -107,9 +102,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
target_mask_list.append(f.pop("target_mask"))
# Determine max lengths
max_teacher_seq_len = max_len or max(
len(seq) for seq in target_logprobs_list
)
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
padded_target_logprobs = []
@@ -216,9 +209,7 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
# We want to produce a single "merged" feature dict for each sub-batch.
out_features = [{} for _ in features]
for i, sub_features in enumerate( # pylint: disable=too-many-nested-blocks
features
):
for i, sub_features in enumerate(features):
# sub_features is a list of dicts, each dict = one sequences features
# We'll merge them into out_features[i].
#
@@ -252,17 +243,10 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
# For example, input_ids or labels are often arrays.
arrays = []
for feat in sub_features:
if field_name in feat and isinstance(
feat[field_name], (list, torch.Tensor)
):
if isinstance(
feat[field_name][0], (dict, str)
): # pylint: disable=too-many-nested-blocks
continue
if field_name in feat:
arr = np.array(feat[field_name])
arrays.append(arr)
if arrays:
out_features[i][field_name] = np.concatenate(arrays)
out_features[i][field_name] = np.concatenate(arrays)
# 3) Now call the parent collator, which will do:
# - padding of labels/position_ids

View File

@@ -1,561 +0,0 @@
"""
Packed data loader for online teacher training supporting vllm and sglang.
"""
import hashlib
import hmac
import logging
from typing import Any, Dict, List, Optional
import requests
import torch
from orjson import orjson
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
from axolotl.integrations.kd.utils import normalize_logprobs
from axolotl.utils.data.utils import retry_on_request_exceptions
LOG = logging.getLogger(__name__)
def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256):
"""
Create HMAC-SHA hash from a list of integers
Args:
int_list: List of integers
key: Secret key (string or bytes)
hash_func: Hash function (default: sha256)
Returns:
HMAC digest as hex string
"""
# Convert key to bytes if it's a string
if isinstance(key, str):
key = key.encode("utf-8")
# Convert list of ints to bytes
# Method 1: Convert each int to bytes and concatenate
data = b"".join(i.to_bytes(4, byteorder="big") for i in int_list)
# Create HMAC
h = hmac.new(key, data, hash_func)
return h.hexdigest()
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
"""
Collator for online teacher training.
"""
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
def __init__(
self,
*args: Any,
kd_online_server_base_url: Optional[str] = None,
kd_online_topk: Optional[int] = None,
kd_temperature: Optional[float] = 1.0,
kd_online_server: Optional[str] = "vllm",
kd_online_timeout: Optional[int] = 120,
kd_cache_dir: Optional[str] = None,
kd_normalize_topk: Optional[bool] = True,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
if kd_online_server_base_url is None:
raise ValueError(
"kd_online_server_base_url must be provided for OnlineTeacherDataloader"
)
if kd_online_topk is None or kd_online_topk <= 0:
raise ValueError(
"kd_online_topk must be a positive integer for OnlineTeacherDataloader"
)
self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/")
self.kd_online_topk = kd_online_topk
self.kd_temperature = kd_temperature
self.kd_online_server = kd_online_server
self.http_session = requests.Session()
self.kd_online_timeout = kd_online_timeout
self.kd_cache_dir = kd_cache_dir
self.kd_normalize_topk = kd_normalize_topk
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
"""
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
"""
if not raw_logprobs or self.kd_online_topk == 0:
return (
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
)
raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_sglang(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
"""
Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
"""
api_endpoint = f"{self.kd_online_server_base_url}/generate"
payload = {
"input_ids": batch_input_ids,
"return_logprob": True,
"top_logprobs_num": self.kd_online_topk,
"logprob_start_len": 0,
"return_text_in_logprobs": True,
"echo": True,
"sampling_params": {
"max_new_tokens": 0,
"temperature": self.kd_temperature,
"skip_special_tokens": False,
},
}
# Initialize with empty lists, so if API call fails, these are returned.
ret_data_target_token_ids: List[List[List[int]]] = []
ret_data_target_logprobs: List[List[List[float]]] = []
ret_data_target_mask: List[List[List[int]]] = []
try:
response = self.http_session.post(
api_endpoint, json=payload, timeout=self.kd_online_timeout
)
response.raise_for_status()
api_data: list[dict] = response.json()
# Ensure api_data is a list, and its length matches batch_input_ids
if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids):
LOG.error(
f"API response format error. Expected a list of {len(batch_input_ids)} "
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
)
# Return empty data; items processed later will get default empty KD fields
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
for sequence_data, seq_input_ids, seq_labels in zip(
api_data, batch_input_ids, labels
):
current_target_logprobs = []
current_target_token_ids = []
current_target_mask = []
meta_info = sequence_data.pop("meta_info", {})
# Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop(
"input_top_logprobs", []
)
if not isinstance(input_top_logprobs, list):
LOG.warning(
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
)
input_top_logprobs = [] # Treat as empty
# basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs)
for i, _, label in zip(
range(len(seq_input_ids)), seq_input_ids, seq_labels
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
# so we replace the None value with padding data
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
elif (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
):
pos_top_logprobs_data = input_top_logprobs[i]
# Ensure pos_top_logprobs_data is a list of lists as expected
if not (
isinstance(pos_top_logprobs_data, list)
and all(
isinstance(item, list) for item in pos_top_logprobs_data
)
and len(pos_top_logprobs_data) > 0
and len(pos_top_logprobs_data[0]) == 3
): # [logprob, token_id, token_str]
LOG.warning(
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
continue
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_logprobs_raw, pos_token_ids, _ = [
list(row) for row in zip(*pos_top_logprobs_data)
]
# Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer
current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk]
)
if self.kd_normalize_topk:
normalized_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(
normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
current_target_mask.append([0] * self.kd_online_topk)
else:
current_target_mask.append([1] * self.kd_online_topk)
else:
# Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
ret_data_target_token_ids.append(current_target_token_ids)
ret_data_target_logprobs.append(current_target_logprobs)
ret_data_target_mask.append(current_target_mask)
except requests.exceptions.RequestException as e:
LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing
LOG.error(
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
exc_info=True,
)
raise e
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_vllm(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
"""
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
"""
api_endpoint = f"{self.kd_online_server_base_url}/v1/completions"
payload = {
"prompt": batch_input_ids,
"echo": True,
"logprobs": True,
"prompt_logprobs": self.kd_online_topk,
"top_logprobs": self.kd_online_topk,
"max_new_tokens": 0,
"skip_special_tokens": False,
"temperature": self.kd_temperature,
"sampling_params": {
"max_tokens": 0,
},
}
# Initialize with empty lists, so if API call fails, these are returned.
ret_data_target_token_ids: List[List[List[int]]] = []
ret_data_target_logprobs: List[List[List[float]]] = []
ret_data_target_mask: List[List[List[int]]] = []
try:
headers = {"Accept-Encoding": "deflate, gzip, br, zstd"}
response = self.http_session.post(
api_endpoint,
json=payload,
headers=headers,
timeout=self.kd_online_timeout,
)
response.raise_for_status()
api_data: dict = orjson.loads(response.content)
choices: list[dict] = api_data["choices"]
            # Ensure choices is a list, and its length matches batch_input_ids
            if not isinstance(choices, list) or len(choices) != len(batch_input_ids):
                LOG.error(
                    f"API response format error. Expected a list of {len(batch_input_ids)} "
                    f"choices, got {type(choices)} with length {len(choices) if isinstance(choices, list) else 'N/A'}."
                )
# Return empty data; items processed later will get default empty KD fields
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
for sequence_data, seq_input_ids, seq_labels in zip(
choices, batch_input_ids, labels
):
# seq_input_ids: List[int]
# seq_labels: List[int]
current_target_logprobs = []
current_target_token_ids = []
current_target_mask = []
# Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None | dict[str, dict]]] = (
sequence_data.pop("prompt_logprobs", [])
)
if not isinstance(input_top_logprobs, list):
LOG.warning(
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
)
input_top_logprobs = [] # Treat as empty
# basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs)
seq_len = len(seq_input_ids)
for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
                    if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
                        # this is always the case for the first token:
                        # there is never logprob data for the first token since that's a true input,
                        # so pad this position (as the sglang path does) to keep the
                        # targets index-aligned with input_ids
                        current_target_logprobs.append(
                            [-float("inf")] * self.kd_online_topk
                        )
                        current_target_token_ids.append(list(range(self.kd_online_topk)))
                        current_target_mask.append([0] * self.kd_online_topk)
                        continue
if (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
):
                        pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i]  # type: ignore[assignment]
                        # Ensure pos_top_logprobs_data is a non-empty dict of
                        # {token_id_str: {"logprob": ...}} entries as expected
                        if not (
                            isinstance(pos_top_logprobs_data, dict)
                            and all(
                                isinstance(item, dict)
                                for item in pos_top_logprobs_data.values()
                            )
                            and len(pos_top_logprobs_data.keys()) > 0
                        ):
LOG.warning(
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
continue
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_token_ids_str = list(pos_top_logprobs_data.keys())
pos_logprobs_dict = pos_top_logprobs_data.values()
pos_token_ids = [
int(token_id) for token_id in pos_token_ids_str
]
pos_logprobs_raw = [
float(logprob.get("logprob", -float("inf")))
for logprob in pos_logprobs_dict
]
# Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
LOG.warning(
f"Padding position {i} with {pad_len} top-k tokens and logprobs."
)
pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer
current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk]
)
if self.kd_normalize_topk:
normalized_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(
normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
current_target_mask.append([0] * self.kd_online_topk)
else:
current_target_mask.append([1] * self.kd_online_topk)
else:
# Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
for i in range(max(0, seq_len - len(current_target_logprobs))):
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(list(range(self.kd_online_topk)))
current_target_mask.append([0] * self.kd_online_topk)
ret_data_target_token_ids.append(current_target_token_ids)
ret_data_target_logprobs.append(current_target_logprobs)
ret_data_target_mask.append(current_target_mask)
# TODO save and load targets to disk for caching for next epoch
# generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int
# if self.kd_cache_dir:
# hash_input_ids = hmac_sha_from_int_list(
# seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}"
# )
# with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f:
# pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False)
except requests.exceptions.RequestException as e:
LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing
LOG.error(
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
exc_info=True,
)
raise e
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
def __call__(
self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
) -> Dict[str, Any]:
if not features:
return super().__call__(features, return_tensors=return_tensors)
for (
sub_batch_features
) in features: # sub_batch_features is List[Dict[str, Any]]
if not sub_batch_features:
continue
input_ids_for_api_call: List[List[int]] = []
labels_for_api_call: List[List[int]] = []
# Store references to the original item dictionaries to update them in-place
items_for_api_call: List[Dict[str, Any]] = []
for item_dict in sub_batch_features:
if not isinstance(item_dict, dict):
LOG.warning(
f"Skipping non-dict item in sub_batch_features: {item_dict}"
)
continue
current_input_ids = item_dict.get("input_ids")
current_labels = item_dict.get("labels")
if current_input_ids is not None and current_labels is not None:
# Ensure input_ids and labels are lists of ints for JSON serialization
input_ids_list = (
current_input_ids.tolist()
if hasattr(current_input_ids, "tolist")
else list(current_input_ids)
)
labels_list = (
current_labels.tolist()
if hasattr(current_labels, "tolist")
else list(current_labels)
)
input_ids_for_api_call.append(input_ids_list)
labels_for_api_call.append(labels_list)
items_for_api_call.append(item_dict)
else:
# This item will not get teacher logprobs from the API.
# Initialize KD fields to empty lists so downstream collators handle them uniformly.
item_dict.setdefault("target_token_ids", [])
item_dict.setdefault("target_logprobs", [])
item_dict.setdefault("target_mask", [])
if items_for_api_call: # Only call API if there's something to process
if self.kd_online_server == "sglang":
api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(
input_ids_for_api_call, labels_for_api_call
)
else:
api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(
input_ids_for_api_call, labels_for_api_call
)
# api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
# Each value is a list, corresponding to items_for_api_call
for i, item_to_update in enumerate(items_for_api_call):
# TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
if api_responses_for_sub_batch and i < len(
api_responses_for_sub_batch["target_token_ids"]
): # Check bounds
assert len(
api_responses_for_sub_batch["target_token_ids"][i]
) == len(item_to_update["input_ids"])
assert len(
api_responses_for_sub_batch["target_logprobs"][i]
) == len(item_to_update["input_ids"])
assert len(
api_responses_for_sub_batch["target_mask"][i]
) == len(item_to_update["labels"])
item_to_update["target_token_ids"] = (
api_responses_for_sub_batch["target_token_ids"][i]
)
item_to_update["target_logprobs"] = api_responses_for_sub_batch[
"target_logprobs"
][i]
item_to_update["target_mask"] = api_responses_for_sub_batch[
"target_mask"
][i]
else:
# API call failed for this item, or response was shorter than expected.
# Ensure KD fields are initialized as empty lists.
                    LOG.warning(
                        f"No teacher logprobs returned for batch item at index {i}, "
                        f"or the API response was too short. API response keys: "
                        f"{list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
                    )
item_to_update.setdefault("target_token_ids", [])
item_to_update.setdefault("target_logprobs", [])
item_to_update.setdefault("target_mask", [])
return super().__call__(features, return_tensors=return_tensors)
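# Hedged wiring sketch (illustrative, not part of the original code): the parent
# collator is assumed to accept the usual DataCollatorForSeq2Seq kwargs such as
# `tokenizer`; the server URL and top-k are placeholders. Each __call__ then
# fetches teacher top-k logprobs before deferring padding to the parent class.
def _demo_collator(tokenizer) -> OnlineTeacherCollator:
    return OnlineTeacherCollator(
        tokenizer=tokenizer,
        kd_online_server_base_url="http://localhost:8000",
        kd_online_topk=64,
        kd_temperature=1.0,
        kd_online_server="vllm",
    )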

View File

@@ -1,485 +0,0 @@
"""
Liger Kernels for Chunked Top-K Log-Prob Distillation
"""
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss.fused_linear_distillation import (
LigerFusedLinearDistillationBase,
)
from axolotl.integrations.kd.utils import normalize_logprobs
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
"""
Chunked kl-div loss for top-k logprobs
"""
@staticmethod
def distillation_loss_fn(
student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
beta: float = 0.0,
normalize_topk: bool = True,
) -> torch.Tensor:
"""
Compute Top-K KL divergence loss for a chunk.
Args:
student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
beta: Controls the type of KL divergence.
0.0 for Forward KL (P_teacher || P_student).
1.0 for Reverse KL (P_student || P_teacher).
0.5 for Symmetric KL (average of Forward and Reverse).
normalize_topk: Whether to normalize the log probabilities
Returns:
Sum of KL divergence losses for the chunk.
"""
topk = target_token_ids_chunk.shape[-1]
student_logits_temp_scaled = ( # [chunk_size, vocab_size]
student_logits_temp_scaled.float()
)
target_logprobs_chunk = target_logprobs_chunk.float()
# Gather student logits for the top-k teacher token IDs
# target_token_ids_chunk: [chunk_size, top_k]
# student_logits_topk_temp_scaled: [chunk_size, top_k]
student_logits_topk_temp_scaled = torch.gather(
student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk
)
# Student log-probabilities for the gathered top-k tokens
student_lse = torch.logsumexp(
student_logits_temp_scaled, dim=-1, keepdim=True
) # [chunk_size, 1]
student_logprobs_topk_temp_scaled = (
student_logits_topk_temp_scaled - student_lse
)
# we have the top-k student logprobs, normalize them
if normalize_topk:
student_logprobs_topk_temp_scaled = normalize_logprobs(
student_logprobs_topk_temp_scaled, topk
)
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
teacher_logprobs_valid = target_logprobs_chunk[valid_mask]
# Teacher probabilities P(y|x_teacher) from logprobs
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
teacher_probs_valid = teacher_logprobs_valid.exp()
# Student probabilities P_student from log P_student
student_probs_topk_valid = student_logprobs_topk_valid.exp()
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
# The distillation loss is often formulated as -sum(P_teacher * log P_student)
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
# Here, target_logprobs_valid are log_softmax_teacher.
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
if beta == 0.0: # Contribution from Forward KL
fwd_kl_per_token = teacher_probs_valid * (
teacher_logprobs_valid - student_logprobs_topk_valid
)
kd_loss = fwd_kl_per_token.sum()
elif beta == 1.0: # Contribution from Reverse KL
rev_kl_per_token = student_probs_topk_valid * (
student_logprobs_topk_valid - teacher_logprobs_valid
)
kd_loss = rev_kl_per_token.sum()
else:
# JSD - Jensen-Shannon Divergence / Symmetric
mean_probs = (
1 - beta
) * student_probs_topk_valid + beta * teacher_probs_valid
log_mean_probs = mean_probs.log()
student_kl = F.kl_div(
log_mean_probs,
student_logprobs_topk_valid,
reduction="sum",
log_target=True,
)
teacher_kl = F.kl_div(
log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True
)
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
kd_loss = jsd_loss
return kd_loss
@staticmethod
def _compute_loss_kl_topk(
student_input_chunk: torch.Tensor,
student_weight: torch.Tensor,
# Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value
# or through `partial`. Let's make them explicit here for clarity.
target_token_ids_chunk: torch.Tensor,
target_logprobs_chunk: torch.Tensor,
target_mask_chunk: torch.Tensor,
target_chunk: torch.Tensor, # For hard loss (true labels)
student_bias: torch.Tensor = None, # This will be one of the grad targets
# Other params passed via `partial` from `forward`
distillation_loss_fn=None,
ignore_index: int = -100,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
compute_ce_loss: bool = True,
temperature: float = 1.0,
beta: float = 0.0,
normalize_topk: bool = True,
):
# Compute student logits for the chunk from hidden states and LM head
# student_input_chunk: [chunk_size, hidden_dim]
# student_lm_head_weight: [vocab_size, hidden_dim]
# student_logits_chunk: [chunk_size, vocab_size]
student_logits_chunk = F.linear(
student_input_chunk, student_weight, student_bias
)
ce_loss = torch.tensor(
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
)
if compute_ce_loss and weight_hard_loss > 0.0:
ce_loss = F.cross_entropy(
student_logits_chunk.view(-1, student_logits_chunk.shape[-1]),
target_chunk.view(-1),
reduction="sum",
ignore_index=ignore_index,
)
soft_loss = torch.tensor(
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
)
if weight_soft_loss > 0.0:
student_logits_chunk_temp_scaled = student_logits_chunk / temperature
# Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max()
# No explicit padding here; user must ensure vocab alignment or pre-pad student_weight.
soft_loss = distillation_loss_fn(
student_logits_chunk_temp_scaled,
target_token_ids_chunk,
target_logprobs_chunk,
target_mask_chunk,
beta=beta,
normalize_topk=normalize_topk,
)
return soft_loss, ce_loss
@classmethod
def forward(
cls,
ctx,
student_input: torch.Tensor, # [batch_size, seq_len, dim]
        student_lm_head_weight: torch.Tensor,  # [vocab_size, dim]
target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k]
target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k]
target_mask: torch.Tensor, # [batch_size, seq_len, top_k]
true_labels: torch.Tensor, # [batch_size, seq_len]
student_lm_head_bias: torch.Tensor = None,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
beta: float = 0.0,
compiled: bool = False,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
normalize_topk: bool = True,
):
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
grad_inputs_list = []
grad_bias_acc = (
torch.zeros_like(student_lm_head_bias)
if student_lm_head_bias is not None
else None
)
kd_loss_acc = torch.zeros(
(), device=student_input.device, dtype=student_input.dtype
)
ce_loss_acc = torch.zeros(
(), device=student_input.device, dtype=student_input.dtype
)
# This function will be what torch.func.grad_and_value differentiates.
# It takes student_input_chunk, student_weight (full), student_bias (full) as primals.
# Other necessary data (target_*, etc.) are passed as non-differentiable arguments.
def loss_fn_for_grad(
_student_input_chunk,
_student_lm_head_weight, # full weight
_student_lm_head_bias, # full bias
# Fixed arguments for a given chunk, not differentiated:
_target_token_ids_chunk,
_target_logprobs_chunk,
_target_mask_chunk,
_true_labels_chunk,
):
return cls._compute_loss_kl_topk(
student_input_chunk=_student_input_chunk,
student_weight=_student_lm_head_weight,
target_token_ids_chunk=_target_token_ids_chunk,
target_logprobs_chunk=_target_logprobs_chunk,
target_mask_chunk=_target_mask_chunk,
target_chunk=_true_labels_chunk,
student_bias=_student_lm_head_bias,
distillation_loss_fn=cls.distillation_loss_fn,
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
compute_ce_loss=compute_ce_loss,
temperature=temperature,
beta=beta,
normalize_topk=normalize_topk,
)
def accumulate_chunk_grads(
student_input_chunk_ac,
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
):
# student_weight and student_bias are closed over from the outer scope (full tensors)
if student_lm_head_bias is not None:
(
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
(chunk_kd_loss, chunk_ce_loss),
) = torch.func.grad_and_value(
loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
)(
student_input_chunk_ac,
student_lm_head_weight,
student_lm_head_bias, # primals
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
) # non-primals
grad_bias_acc.add_(chunk_grad_bias)
else:
argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight
(
(chunk_grad_input, chunk_grad_weight), # No grad for bias
(chunk_kd_loss, chunk_ce_loss),
) = torch.func.grad_and_value(
loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
)(
student_input_chunk_ac,
student_lm_head_weight,
None, # Pass None for student_bias primal
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
)
grad_weight_acc.add_(chunk_grad_weight)
kd_loss_acc.add_(chunk_kd_loss)
ce_loss_acc.add_(chunk_ce_loss)
return chunk_grad_input
if compiled:
accumulate_chunk_grads_compiled = torch.compile(
accumulate_chunk_grads, dynamic=True, backend="inductor"
) # dynamic=True often helpful
else:
accumulate_chunk_grads_compiled = accumulate_chunk_grads
# Use the same chunking logic as LigerFusedLinearDistillationBase.forward
B, N, D = student_input.shape # pylint: disable=invalid-name
K = target_token_ids.shape[-1] # pylint: disable=invalid-name
student_input_flat = student_input.reshape(-1, student_input.shape[-1])
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
# pad and shift for cross entropy loss
true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
_student_input_chunks = torch.chunk(
student_input_flat, chunks=num_chunks, dim=0
)
_target_token_ids_chunks = torch.chunk(
target_token_ids_flat, chunks=num_chunks, dim=0
)
_target_logprobs_chunks = torch.chunk(
target_logprobs_flat, chunks=num_chunks, dim=0
)
_target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0)
_true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0)
for i in range(num_chunks):
grad_input_chunk = accumulate_chunk_grads_compiled(
_student_input_chunks[i],
_target_token_ids_chunks[i],
_target_logprobs_chunks[i],
_target_mask_chunks[i],
_true_labels_chunks[i],
)
grad_inputs_list.append(grad_input_chunk)
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature
ctx.bias_was_none = student_lm_head_bias is None
ctx.orig_dims = (B, N, D, K)
# since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum
# we still need to scale the kd_loss by the temp^2
kd_loss_acc = kd_loss_acc * (temperature**2)
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
return final_loss
@staticmethod
def backward(ctx, grad_output):
grad_input_flat, grad_weight, grad_bias_maybe = (
ctx.saved_tensors
) # grad_input_flat is (B*N, D)
# Scale gradients by grad_output if it's not 1.0
if not torch.equal(
grad_output,
torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype),
):
grad_input_flat = grad_input_flat * grad_output
grad_weight = grad_weight * grad_output
if grad_bias_maybe is not None:
grad_bias_maybe = grad_bias_maybe * grad_output
# Reshape grad_input_flat to match original student_input shape (B, N, D)
# ctx.orig_dims stores (B, N, D, K)
# We need the first three dimensions for student_input's shape.
# Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors
if (
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
and grad_input_flat.numel() == 0
):
# If original input was empty, gradient should also be empty with correct shape
grad_input_reshaped = torch.zeros(
ctx.orig_dims[0],
ctx.orig_dims[1],
ctx.orig_dims[2],
dtype=grad_input_flat.dtype,
device=grad_input_flat.device,
)
elif grad_input_flat.numel() == 0 and not (
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
):
# This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad)
# but as a safeguard:
grad_input_reshaped = torch.zeros(
ctx.orig_dims[0],
ctx.orig_dims[1],
ctx.orig_dims[2],
dtype=grad_input_flat.dtype,
device=grad_input_flat.device,
)
else:
grad_input_reshaped = grad_input_flat.view(
ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2]
)
nones_for_hyperparams = [None] * ctx.hyperparams_count
grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None
return (
grad_input_reshaped, # Gradient for student_input (reshaped)
grad_weight, # Gradient for student_lm_head_weight
None, # Gradient for target_token_ids
None, # Gradient for target_logprobs
None, # Gradient for target_mask
None, # Gradient for true_labels
grad_bias_return, # Gradient for student_lm_head_bias
*nones_for_hyperparams, # Grads for weight_hard_loss, ..., compute_ce_loss
)
class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
    """
    Module wrapper for the chunked top-k logprob KL divergence loss.
    """
def __init__(
self,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
temperature: float = 1.0, # This is the kd_temperature
beta: float = 1.0,
ignore_index: int = -100,
compiled: bool = True,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
normalize_topk: bool = True,
):
super().__init__()
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
raise ValueError("Loss weights must be between 0.0 and 1.0.")
if temperature <= 0:
raise ValueError("Temperature must be positive.")
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.temperature = temperature
self.beta = beta
self.ignore_index = ignore_index
self.compiled = compiled
self.chunk_size = chunk_size
self.compute_ce_loss = compute_ce_loss
self.normalize_topk = normalize_topk
if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
print(
f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero."
)
# self.weight_hard_loss = 0.0 # Or let user manage this
if self.weight_soft_loss == 0.0:
print(
"Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed."
)
def forward(
self,
lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head
student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head
target_token_ids: torch.Tensor,
target_logprobs: torch.Tensor,
target_mask: torch.Tensor,
true_labels: torch.Tensor,
student_bias: torch.Tensor = None,
) -> torch.Tensor:
return LigerFusedLinearKLTopKLogprobFunction.apply(
student_hidden_states,
lm_head_weight,
target_token_ids,
target_logprobs,
target_mask,
true_labels,
student_bias,
self.weight_hard_loss,
self.weight_soft_loss,
self.ignore_index,
self.temperature,
self.beta,
self.compiled,
self.chunk_size,
self.compute_ce_loss,
self.normalize_topk,
)
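# Illustrative smoke test (not part of the original code): random tensors with
# small shapes, compiled=False to avoid requiring an inductor backend. Exercises
# the chunked forward and the custom backward end to end.
def _demo_chunked_loss() -> torch.Tensor:
    B, N, D, V, K = 2, 16, 64, 128, 8  # pylint: disable=invalid-name
    loss_fn = LigerFusedLinearKLTopKLogprobLoss(chunk_size=8, compiled=False)
    lm_head_weight = torch.randn(V, D, requires_grad=True)
    hidden = torch.randn(B, N, D, requires_grad=True)
    target_token_ids = torch.randint(0, V, (B, N, K))
    target_logprobs = torch.log_softmax(torch.randn(B, N, K), dim=-1)
    target_mask = torch.ones(B, N, K, dtype=torch.long)
    labels = torch.randint(0, V, (B, N))
    out = loss_fn(
        lm_head_weight, hidden, target_token_ids, target_logprobs, target_mask, labels
    )
    out.backward()  # populates hidden.grad and lm_head_weight.grad
    return out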

View File

@@ -1,97 +0,0 @@
"""
model patcher for chunked top-k kl-div
"""
from typing import Optional, Union, Unpack
import torch
from transformers import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import LossKwargs
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
"""
placeholder kwargs for hf model classes
"""
def kldiv_forward_llama_like(
self,
input_ids: Optional[torch.LongTensor] = None,
target_logprobs: Optional[torch.Tensor] = None,
target_token_ids: Optional[torch.LongTensor] = None,
target_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
**kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc]
) -> CausalLMOutputWithPast:
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
loss = self.loss_function(
self.lm_head.weight,
hidden_states,
target_token_ids,
target_logprobs,
target_mask,
true_labels=labels,
)
num_items_in_batch = kwargs.pop("num_items_in_batch", -1)
if num_items_in_batch is not None and num_items_in_batch > 0:
loss = loss / num_items_in_batch
return CausalLMOutputWithPast(
loss=loss,
logits=None,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_kernel(model_type):
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = kldiv_forward_llama_like
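# Usage sketch (illustrative): patch Llama's causal-LM forward. The trainer is
# expected to attach the top-k KD loss as the model's loss function separately.
def _demo_apply_llama_kernel():
    # "llama" resolves to transformers.models.llama.modeling_llama.LlamaForCausalLM
    apply_kernel("llama")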

View File

@@ -16,7 +16,40 @@
loss for top_k KL divergence
"""
import torch
from torch import nn
def zscore_standardize(
logits: torch.Tensor,
mask: torch.Tensor = None,
base_temperature: float = 1.0,
eps: float = 1e-9,
):
"""
Z-score standardize along the last dimension of `logits`.
i.e., for each [B, seq_len] row, across K entries:
z = (logits - mean) / std,
then scale by 1 / base_temperature if desired.
mask can be broadcastable or None. If None, we standardize all elements.
"""
if mask is None:
# shape: [B, seq_len, K]
# Mean and std over dim=-1
mean = logits.mean(dim=-1, keepdim=True)
var = logits.var(dim=-1, unbiased=False, keepdim=True)
else:
# If you have to exclude some tokens, multiply by mask, etc.
float_mask = mask.to(logits.dtype)
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
std = torch.sqrt(var.clamp_min(eps))
z = (logits - mean) / std
# Scale by 1 / base_temperature
z = z / base_temperature
return z
@torch.jit.script
@@ -27,6 +60,7 @@ def loss(
target_mask: torch.Tensor,
num_items_in_batch: int = -1, # Use -1 to indicate "None"
kd_temperature: float = 1.0,
top_k_before_softmax: int = 0,
) -> torch.Tensor:
"""
A KD loss function that is TorchScript-friendly.
@@ -43,6 +77,8 @@ def loss(
num_items_in_batch (int, optional): The number of items in the batch.
kd_temperature (float, optional): The temperature for KD.
Default: 1.0
        top_k_before_softmax (int, optional): If nonzero, gather the student's top-k
            logits first and normalize over the top-k set only (softmax after gathering).
            Default: 0
"""
target_logprobs = target_logprobs.float()
@@ -52,24 +88,46 @@ def loss(
# student_logits shape: [B, student_seq_len, vocab_size]
teacher_seq_len = target_token_ids.shape[1]
# Slice student logits to match teacher-provided sequence length
    if top_k_before_softmax:
        # Slice student logits to match teacher-provided sequence length
        student_logits_for_kd = student_logits[
            :, :teacher_seq_len, :
        ]  # [B, teacher_seq_len, vocab_size]
        # keep in full precision for numerical stability of loss
        student_logits_for_kd = student_logits_for_kd.float()
        # Gather student logits for teacher's top-K tokens
        student_logits_topk = torch.gather(
            student_logits_for_kd, dim=-1, index=target_token_ids
        )  # [B, teacher_seq_len, K]
        # Apply KD temperature to the student's top-k logits
        if kd_temperature != 1.0:
            student_logits_topk = student_logits_topk / kd_temperature
        # Convert student top-k logits to logprobs over the top-k set only
        student_logprobs_topk = student_logits_topk - torch.logsumexp(
            student_logits_topk, dim=-1, keepdim=True
        )  # [B, teacher_seq_len, K]
    else:
        # Slice student logits to match teacher-provided sequence length
        student_logits_for_kd = (
            student_logits[:, :teacher_seq_len, :] / kd_temperature
        )  # [B, teacher_seq_len, vocab_size]
        # keep in full precision for numerical stability of loss
        student_logits_for_kd = student_logits_for_kd.float()
        # Gather student logits for teacher's top-K tokens
        student_logits_topk = torch.gather(
            student_logits_for_kd, dim=-1, index=target_token_ids
        )  # [B, teacher_seq_len, K]
        # Compute logsumexp across full vocabulary
        student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
        # Convert just the top-k logits to logprobs
        student_logprobs_topk = student_logits_topk - student_lse
# Convert teacher_mask to boolean for indexing
# In TorchScript, .bool() is sometimes unsupported, so we do:
@@ -86,6 +144,10 @@ def loss(
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum()
# Multiply by T^2 (classical KD scaling)
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items (if provided) or by valid tokens
if num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
@@ -96,74 +158,80 @@ def loss(
return kd_loss
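# Illustrative call of the scripted loss above (random tensors, shapes only; a
# real batch would come from the KD collator):
def _demo_topk_kd_loss() -> torch.Tensor:
    B, T, V, K = 2, 8, 32, 4  # pylint: disable=invalid-name
    return loss(
        student_logits=torch.randn(B, T, V),
        target_token_ids=torch.randint(0, V, (B, T, K)),
        target_logprobs=torch.log_softmax(torch.randn(B, T, K), dim=-1),
        target_mask=torch.ones(B, T, K),
    )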
def topk_kd_loss_with_zscore(
    student_logits: torch.Tensor,  # [B, seq_len, vocab_size]
    target_token_ids: torch.Tensor,  # [B, seq_len, K]
    target_logprobs: torch.Tensor,  # [B, seq_len, K], sums to 1.0 in prob space
    target_mask: torch.Tensor,  # [B, seq_len, K] or [B, seq_len]
    kd_temperature: float = 1.0,  # classic KD temperature
    zscore_base_temp: float = 1.0,  # from the paper
    num_items_in_batch: int = -1,
):
    """
    A variant of top_k KL divergence with Z-score scaling
    from "Logit Standardization in Knowledge Distillation".
    """
    target_logprobs = target_logprobs.float()
    B, teacher_seq_len, K = target_logprobs.shape  # pylint: disable=invalid-name
    # 1) Gather the student's top-k logits to match teacher
    student_logits_for_kd = student_logits[
        :, :teacher_seq_len, :
    ]  # [B, seq_len, vocab]
    student_topk_logits = torch.gather(
        student_logits_for_kd, dim=-1, index=target_token_ids
    )  # [B, seq_len, K]
    student_topk_logits = student_topk_logits.float()
    # 2) If you want to keep the "classical" T scaling, apply it first
    if kd_temperature != 1.0:
        student_topk_logits = student_topk_logits / kd_temperature
    # 3) Convert teacher logprobs -> treat them as "logits" for z-score
    #    (They differ by +some_constant from real logits, but in z-score
    #    that constant is subtracted out anyway.)
    teacher_logits_for_zscore = target_logprobs  # rename variable for clarity
    # 4) Z-score teacher and student
    #    If target_mask is 2D, expand to 3D for the K dimension
    if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
        target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
    teacher_z = zscore_standardize(
        teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
    )
    student_z = zscore_standardize(
        student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
    )
    # 5) Convert to log-probs for KL
    teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
    student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
    # 6) Restrict to valid tokens if needed
    valid_mask = target_mask.bool()  # shape [B, seq_len, K]
    teacher_probs_z = teacher_logprobs_z.exp()
    teacher_probs_z = teacher_probs_z[valid_mask]
    teacher_logprobs_z = teacher_logprobs_z[valid_mask]
    student_logprobs_z = student_logprobs_z[valid_mask]
    # 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
    kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
    kd_loss = kd_loss_per_token.sum()
    # 8) If using classical KD scaling by T^2
    if kd_temperature != 1.0:
        kd_loss = kd_loss * (kd_temperature**2)
    # Optionally scale by zscore_base_temp**2 if you want (paper might differ).
    # kd_loss = kd_loss * (zscore_base_temp**2)
    # 9) Normalize
    if num_items_in_batch is not None and num_items_in_batch > 0:
        kd_loss = kd_loss / float(num_items_in_batch)
    else:
        kd_loss = kd_loss / float(kd_loss_per_token.size(0))
    return kd_loss
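# Quick shape check for the z-score variant (illustrative random tensors):
def _demo_topk_kd_loss_with_zscore() -> torch.Tensor:
    B, T, V, K = 2, 8, 32, 4  # pylint: disable=invalid-name
    return topk_kd_loss_with_zscore(
        student_logits=torch.randn(B, T, V),
        target_token_ids=torch.randint(0, V, (B, T, K)),
        target_logprobs=torch.log_softmax(torch.randn(B, T, K), dim=-1),
        target_mask=torch.ones(B, T, K),
    )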

View File

@@ -18,7 +18,8 @@ KD trainer
from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainer(AxolotlTrainer):
@@ -26,18 +27,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
Custom trainer subclass for Knowledge Distillation (KD)
"""
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
columns_to_add = []
@@ -63,12 +52,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
Subclass and override for custom behavior.
"""
if (
self.args.sample_packing
and hasattr(inputs, "attention_mask")
and hasattr(inputs, "position_ids")
):
del inputs["attention_mask"]
target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask")
seq_len = target_token_ids.shape[1]
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
@@ -76,4 +65,49 @@ class AxolotlKDTrainer(AxolotlTrainer):
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
# FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
if self.args.kd_zscore_base_temp:
loss_kd = topk_kd_loss_with_zscore(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
kd_temperature=self.args.kd_temperature,
zscore_base_temp=self.args.kd_zscore_base_temp,
num_items_in_batch=num_items_in_batch,
)
else:
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
)
if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
self.args.past_index
]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes
return (loss, outputs) if return_outputs else loss

View File

@@ -1,100 +0,0 @@
"""Helper KD utils"""
import math
from typing import List, Union
import numpy as np
import torch
from torch import FloatTensor, Tensor
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
"""
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
"""
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
# This should ideally be handled by the caller ensuring correct padding/truncation first
if logprobs.shape[-1] != topk:
# pad last dimension of logprobs to match topk length with -inf
padding_len = topk - logprobs.shape[-1]
padding_tensor = torch.full(
(
*logprobs.shape[:-1],
padding_len,
), # Takes all dimensions of logprobs except the last, then appends padding_needed
float("-inf"),
dtype=logprobs.dtype,
device=logprobs.device,
)
logprobs = torch.cat((logprobs, padding_tensor), dim=-1)
# Convert logprobs at T_online to probabilities
# use log sum exp trick to avoid underflow
position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)
teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)
# Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
final_logprobs_tensor = torch.log(teacher_probs_t_online)
return final_logprobs_tensor
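# Sanity check (illustrative): top-3 logprobs covering only 90% of the
# probability mass renormalize to a proper distribution.
def _demo_normalize_logprobs() -> torch.Tensor:
    logprobs = torch.tensor([0.5, 0.3, 0.1]).log()
    return normalize_logprobs(logprobs, topk=3).exp().sum()  # ~= 1.0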
def strided_chunk_views(
tensor: Union[np.ndarray, torch.Tensor],
chunks: int,
dim: int = 0,
stride: int = 1,
chunk_size: int | None = None,
) -> List[Union[np.ndarray, torch.Tensor]]:
"""
Split a tensor into chunks along a dimension with striding, prioritizing views over copies.
Args:
tensor: Input tensor (numpy array or torch tensor)
chunks: Number of chunks to create
dim: Dimension along which to chunk (default: 0)
stride: Stride between chunk starting positions (default: 1)
chunk_size: Size of each chunk. If None, calculated automatically (default: None)
Returns:
List of tensor chunks (views when possible, copies when necessary)
"""
# Get the size of the specified dimension
dim_size = tensor.shape[dim]
# Calculate chunk size if not provided
if chunk_size is None:
chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division
chunks_list = []
for i in range(chunks):
start_idx = i * stride
end_idx = min(start_idx + chunk_size, dim_size)
# Break if we've gone beyond the tensor
if start_idx >= dim_size:
break
# Create slice objects for all dimensions
slices = [slice(None)] * tensor.ndim
slices[dim] = slice(start_idx, end_idx)
chunk = tensor[tuple(slices)]
chunks_list.append(chunk)
return chunks_list
def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):
    """Split `input_tensor` into `chunks` along `dim`, with consecutive chunks
    overlapping by `overlap` elements."""
    dim_size = input_tensor.shape[dim]
    stride = math.ceil(dim_size / chunks)
return strided_chunk_views(
input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap
)
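# Example (illustrative): 10 elements into 3 overlapping chunks. With chunks=3
# the stride is ceil(10 / 3) = 4 and each chunk spans 5 elements:
# [0..4], [4..8], [8, 9].
def _demo_chunk_overlap() -> List[Tensor]:
    return chunk_overlap(torch.arange(10), chunks=3, overlap=1)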

View File

@@ -19,6 +19,7 @@ from peft import (
from transformers import PreTrainedModel
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -162,6 +163,7 @@ def load_lora(
return model, lora_config
@send_errors
def load_adapter(
model: PreTrainedModel,
cfg: DictDefault,

View File

@@ -46,6 +46,7 @@ from axolotl.loaders.utils import (
load_model_config,
)
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.telemetry.errors import send_errors
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
@@ -145,6 +146,7 @@ class ModelLoader:
"""Property that determines if FSDP with QLoRA is enabled."""
return self.cfg.fsdp and self.cfg.adapter == "qlora"
@send_errors
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.

View File

@@ -8,12 +8,14 @@ from transformers import (
PreTrainedTokenizerBase,
)
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@send_errors
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs: dict[str, Any] = {} # Do we actually need this?

View File

@@ -12,6 +12,7 @@ from transformers import (
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.telemetry.errors import send_errors
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.distributed import (
barrier,
@@ -117,6 +118,7 @@ def modify_tokenizer_files(
return tokenizer_dir
@send_errors
def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg)

View File

@@ -0,0 +1,164 @@
"""Trainer callbacks for reporting runtime metrics at regular intervals."""
import logging
import time
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.telemetry.manager import TelemetryManager
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker
LOG = logging.getLogger(__name__)
TIME_SINCE_LAST = 30  # minimum seconds between telemetry reports
class TelemetryCallback(TrainerCallback):
"""
Trainer callback for tracking and reporting runtime metrics.
This callback tracks training progress, runtime, and memory usage,
sending telemetry at configurable intervals.
"""
report_interval_steps: int = 100
def __init__(self):
"""Initialize the metrics callback."""
self.tracker = RuntimeMetricsTracker()
self.telemetry_manager = TelemetryManager.get_instance()
self.current_epoch = -1
self.start_time = time.time()
self.last_report_time = None
self.last_report_step = 0
# pylint: disable=unused-argument
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle training start."""
self.telemetry_manager.send_event(event_type="train-start")
# pylint: disable=unused-argument
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle training end."""
# Send training completion event
self.telemetry_manager.send_event(
event_type="train-end",
properties=self._extract_last_metrics(state)
| self.tracker.metrics.to_dict(),
)
# pylint: disable=unused-argument
def on_epoch_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle epoch start."""
self.current_epoch += 1
self.tracker.start_epoch(self.current_epoch)
# pylint: disable=unused-argument
def on_epoch_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle epoch end."""
self.tracker.end_epoch(self.current_epoch)
# pylint: disable=unused-argument
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle step end."""
step = state.global_step
self.tracker.update_step(step)
# Check if we should report metrics
should_report = (
step % self.report_interval_steps == 0
or step == 1 # Always report first step
or step - self.last_report_step >= self.report_interval_steps
)
if should_report:
current_time = time.time()
if self.last_report_time is not None:
time_since_last_report = current_time - self.last_report_time
else:
time_since_last_report = current_time - self.start_time
steps_since_last_report = step - self.last_report_step
# Only report if enough time has passed
if (
step == 1
or time_since_last_report >= TIME_SINCE_LAST
or steps_since_last_report >= self.report_interval_steps
):
# Calculate steps per second for this interval
if time_since_last_report > 0 and steps_since_last_report > 0:
steps_per_second = steps_since_last_report / time_since_last_report
else:
steps_per_second = 0
# Update memory metrics
self.tracker.update_memory_metrics()
# Prepare metrics to report
metrics = self._extract_last_metrics(state) | {
"step": step,
"epoch": self.current_epoch,
"progress": state.epoch, # Fractional epoch progress
"steps_per_second": steps_per_second,
"elapsed_time": current_time - self.start_time,
"time_since_last_report": time_since_last_report,
}
# Add memory metrics
memory_metrics = self.tracker.get_memory_metrics()
metrics.update({"memory": memory_metrics})
# Send telemetry
self.telemetry_manager.send_event(
event_type="train-progress", properties=metrics
)
# Update last report time and step
self.last_report_time = current_time
self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss and learning_rate from log history."""
if not state.log_history:
return {"loss": 0, "learning_rate": 0}
last_log = state.log_history[-1]
return {
"loss": last_log.get("loss", 0),
"learning_rate": last_log.get("learning_rate", 0),
}
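# Hedged wiring sketch: axolotl registers this callback internally, but on a
# plain Hugging Face Trainer it can be attached the usual way:
def _demo_attach_callback(trainer):
    trainer.add_callback(TelemetryCallback())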

View File

@@ -0,0 +1,160 @@
"""Telemetry utilities for exception and traceback information."""
import logging
import os
import re
import traceback
from functools import wraps
from inspect import getmodule
from typing import Any, Callable
from axolotl.telemetry.manager import TelemetryManager
LOG = logging.getLogger(__name__)
ERROR_HANDLED = False
def sanitize_stack_trace(stack_trace: str) -> str:
"""
Remove personal information from stack trace messages while keeping Python package codepaths.
This function identifies Python packages by looking for common patterns in virtual environment
and site-packages directories, preserving the package path while removing user-specific paths.
Args:
stack_trace: The original stack trace string.
Returns:
A sanitized version of the stack trace with Python package paths preserved.
"""
# Split the stack trace into lines to process each file path separately
lines = stack_trace.split("\n")
sanitized_lines = []
# Regular expression to find file paths in the stack trace
path_pattern = re.compile(r'(?:File ")(.*?)(?:")')
# Regular expression to identify paths in site-packages or dist-packages
# This matches path segments like "site-packages/package_name" or "dist-packages/package_name"
site_packages_pattern = re.compile(
r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
)
# Additional common virtual environment patterns
venv_lib_pattern = re.compile(
r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
)
for line in lines:
# Check if this line contains a file path
path_match = path_pattern.search(line)
if path_match:
full_path = path_match.group(1)
sanitized_path = ""
# Try to match site-packages pattern
site_packages_match = site_packages_pattern.search(full_path)
venv_lib_match = venv_lib_pattern.search(full_path)
if site_packages_match:
# Find the index where the matched pattern starts
idx = full_path.find("site-packages")
if idx == -1:
idx = full_path.find("dist-packages")
# Keep from 'site-packages' onward
if idx >= 0:
sanitized_path = full_path[idx:]
elif venv_lib_match:
# For other virtual environment patterns, find the package directory
match_idx = venv_lib_match.start(1)
if match_idx > 0:
# Keep from the package name onward
package_name = venv_lib_match.group(1)
idx = full_path.rfind(
package_name, 0, match_idx + len(package_name)
)
if idx >= 0:
sanitized_path = full_path[idx:]
# If we couldn't identify a package pattern but path contains 'axolotl'
elif "axolotl" in full_path:
idx = full_path.rfind("axolotl")
if idx >= 0:
sanitized_path = full_path[idx:]
# Apply the sanitization to the line
if sanitized_path:
line = line.replace(full_path, sanitized_path)
else:
# If we couldn't identify a package pattern, just keep the filename
filename = os.path.basename(full_path)
if filename:
line = line.replace(full_path, filename)
else:
line = line.replace(full_path, "")
sanitized_lines.append(line)
return "\n".join(sanitized_lines)
def send_errors(func: Callable) -> Callable:
"""
Decorator that reports exception info for a function. If an exception is raised, we send
telemetry containing the sanitized stack trace and error message.
If an error occurs in a decorated function that is called by another decorated
function, we'll only send telemetry corresponding to the lower-level function.
Args:
func: Function to decorate.
Returns:
Decorated function.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> Any:
telemetry_manager = TelemetryManager.get_instance()
if not telemetry_manager.enabled:
return func(*args, **kwargs)
try:
return func(*args, **kwargs)
except Exception as exception:
# Only track if we're not already handling an error. This prevents us from
# capturing an error more than once in nested decorated function calls.
global ERROR_HANDLED # pylint: disable=global-statement
if not ERROR_HANDLED:
ERROR_HANDLED = True
# Get function module path
module = getmodule(func)
module_path = (
f"{module.__name__}.{func.__name__}" if module else func.__name__
)
# Get stack trace
stack_trace = "".join(
traceback.format_exception(
type(exception), exception, exception.__traceback__
)
)
stack_trace = sanitize_stack_trace(stack_trace)
# Send error telemetry
telemetry_manager.send_event(
event_type=f"{module_path}-error",
properties={
"exception": str(exception),
"stack_trace": stack_trace,
},
)
raise
return wrapper
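A minimal usage sketch for the decorator (`load_model` here is a hypothetical function, and the event only fires when telemetry is enabled):
from axolotl.telemetry.errors import send_errors
@send_errors
def load_model(path: str):
    raise FileNotFoundError(f"no model found at {path}")
# The exception is re-raised after a single "<module>.load_model-error" event
# carrying the sanitized stack trace is sent.
load_model("/tmp/does-not-exist")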

View File

@@ -0,0 +1,417 @@
"""Telemetry manager and associated utilities."""
import atexit
import importlib
import logging
import os
import platform
import time
import uuid
from pathlib import Path
from typing import Any
import posthog
import psutil
import torch
import yaml
LOG = logging.getLogger(__name__)
POSTHOG_HOST = "https://app.posthog.com"
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
OPT_IN_WARNING_SLEEP_SECONDS = 10
OPT_IN_WARNING = (
"\nTelemetry is currently disabled by default. If you'd like to help improve "
"Axolotl, consider enabling it by setting AXOLOTL_DO_NOT_TRACK=0 in your environment.\n\n"
"Telemetry data helps us understand:\n"
"- Which features are most used\n"
"- What hardware configurations to prioritize\n"
"- Where users encounter errors\n\n"
"Personally identifiable information (PII) is not collected.\n\n"
"To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) "
"or AXOLOTL_DO_NOT_TRACK=1 (explicitly disable telemetry).\n\n"
"Note: Telemetry will move to an opt-out in a later release.\n\n"
"For details, see: https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html\n\n"
f"Sleeping for {OPT_IN_WARNING_SLEEP_SECONDS}s..."
)
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
# NOTE: Need to keep these up to date with any config schema changes
FIELDS_TO_REDACT = {
"base_model",
"tokenizer_config",
"base_model_config",
"pretraining_dataset", # NOTE: this field may be a string or a dictionary
"resume_from_checkpoint",
"hub_model_id",
}
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
PATH_INDICATORS = {"path", "dir"}
# pylint: disable=duplicate-code
RELEVANT_PACKAGES = {
"torch",
"transformers",
"trl",
"datasets",
"peft",
"bitsandbytes",
"accelerate",
"optimum",
"deepspeed",
"ray",
"axolotl",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"tokenizers",
"sentencepiece",
"torchao",
"lm_eval",
}
def is_main_process() -> bool:
"""
Check whether we're running in the main process.
Note:
We use this function instead of `torch.utils.distributed.is_main_process`,
since the latter causes issues with DeepSpeed world_size. This function avoids
that issue by checking env vars that are set by various launchers.
Returns:
Whether we're running in the main process.
"""
# If PyTorch distributed is already initialized, use it
if torch.distributed.is_initialized():
return torch.distributed.get_rank() == 0
# Otherwise check environment variables for global rank
# NOTE: need to verify this in SLURM / OpenMPI environments
global_rank = int(
os.environ.get(
"RANK",
os.environ.get(
"GLOBAL_RANK",
os.environ.get(
"SLURM_PROCID",
os.environ.get(
"OMPI_COMM_WORLD_RANK",
"0",
),
),
),
)
)
return global_rank == 0
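# Illustrative example (not part of the original file), assuming
# torch.distributed has not been initialized in this process:
#   os.environ["RANK"] = "1"  # e.g. set by torchrun on a non-zero rank
#   is_main_process()         # -> False; only global rank 0 is "main"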
class TelemetryManager:
"""Manages telemetry collection and transmission"""
_instance = None
_initialized = False
def __new__(cls):
"""
Telemetry manager constructor. Creates the singleton instance of this class if
it doesn't already exist.
"""
if cls._instance is None:
cls._instance = super(TelemetryManager, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
"""Telemetry manager initializer"""
if self._initialized:
return
self.enabled = self._check_telemetry_enabled()
if self.enabled:
self.run_id = str(uuid.uuid4())
self.whitelist = self._load_whitelist()
try:
self.system_info = self._get_system_info()
except Exception as e: # pylint: disable=broad-exception-caught
LOG.warning(f"Error during system info collection: {e}")
self.system_info = None
self._init_posthog()
# Register shutdown method to flush posthog telemetry
atexit.register(self.shutdown)
self._initialized = True
@classmethod
def get_instance(cls) -> "TelemetryManager":
if cls._instance is None:
cls._instance = TelemetryManager()
return cls._instance
def _check_telemetry_enabled(self) -> bool:
"""
Check if telemetry is enabled based on environment variables. We also check
whether this is the main process (for the distributed setting and to avoid
sending duplicate PostHog events per GPU).
Note: Telemetry is disabled by default (opt-in). Set
`AXOLOTL_DO_NOT_TRACK=0` to enable telemetry. We plan to move to an opt-out
model in a later release. For more details, see
https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
Returns:
Boolean denoting whether telemetry is enabled or not.
"""
# Parse relevant env vars
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
do_not_track = os.getenv("DO_NOT_TRACK")
# Default to disabled (opt-in model for initial release)
if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (
"0",
"1",
"false",
"true",
):
# Print opt-in info message for main process only
if is_main_process():
LOG.warning(OPT_IN_WARNING)
time.sleep(OPT_IN_WARNING_SLEEP_SECONDS)
return False
# Only rank 0 will send telemetry
if not is_main_process():
return False
if do_not_track is None:
do_not_track = "0"
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
enabled = axolotl_do_not_track.lower() not in (
"1",
"true",
) and do_not_track.lower() not in ("1", "true")
return enabled
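# Illustrative example (not part of the original file): opting in before the
# singleton is first constructed, e.g. at the top of a launcher script:
#   os.environ["AXOLOTL_DO_NOT_TRACK"] = "0"
#   TelemetryManager.get_instance().enabled  # -> True on the main process,
#                                            # if DO_NOT_TRACK isn't "1"/"true"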
def _load_whitelist(self) -> dict:
"""Load HuggingFace Hub organization whitelist"""
with open(WHITELIST_PATH, encoding="utf-8") as f:
whitelist = yaml.safe_load(f)
# Convert org strings to lowercase since org names are case insensitive
whitelist["organizations"] = {
org.lower() for org in whitelist["organizations"]
}
return whitelist
def _is_whitelisted(self, value: str) -> bool:
"""
Check if model / dataset / etc. org is in whitelist.
Args:
value: Value for one of `axolotl.telemetry.manager.FIELDS_TO_REDACT`
("base_model", etc.).
Returns:
Boolean indicating whitelist membership.
"""
# NOTE: This membership-checking logic can be improved.
# What happens when a local model path matches a whitelisted org?
parts = value.split("/")
if len(parts) < 2:
return False
org = parts[0]
whitelisted = org.lower() in self.whitelist["organizations"]
return whitelisted
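# Illustrative examples (not part of the original file):
#   self._is_whitelisted("meta-llama/Llama-2-7b-hf")  # -> True (org in whitelist.yaml)
#   self._is_whitelisted("local-model")               # -> False (no org/name split)
#   self._is_whitelisted("some-user/custom-model")    # -> False (org not whitelisted)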
def _init_posthog(self):
"""Initialize PostHog client"""
posthog.host = POSTHOG_HOST
posthog.project_api_key = POSTHOG_WRITE_KEY
def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:
"""
Redact properties to remove any paths, so as to avoid inadvertently collecting
private or personally identifiable information (PII). We also remove
configuration related to Wandb, Comet, MLflow, and Gradio.
Args:
properties: Dictionary of properties to redact.
Returns:
Properties dictionary with redaction applied.
"""
if not properties:
return {}
def redact_value(value: Any, key: str = "") -> Any:
"""Recursively sanitize values, redacting those with path-like keys"""
if isinstance(key, str) and isinstance(value, str):
# Redact matching fields, prefixed keys, and path-like keys
if (
key in FIELDS_TO_REDACT
or any(prefix in key for prefix in PREFIXES_TO_REDACT)
or any(indicator in key.lower() for indicator in PATH_INDICATORS)
):
# Fields with whitelisted orgs don't need to be redacted
if not self._is_whitelisted(value):
return "[REDACTED]"
# Handle nested values
if isinstance(value, dict):
return {k: redact_value(v, k) for k, v in value.items()}
if isinstance(value, list):
return [redact_value(item) for item in value]
return value
# Create new dict with redacted values
redacted = {k: redact_value(v, k) for k, v in properties.items()}
return redacted
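# Illustrative example (not part of the original file):
#   self._redact_paths({
#       "base_model": "meta-llama/Llama-2-7b-hf",  # whitelisted org -> preserved
#       "output_dir": "/home/user/out",            # key contains "dir" -> "[REDACTED]"
#       "micro_batch_size": 2,                     # untouched
#   })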
def _get_system_info(self) -> dict[str, Any]:
"""Collect system information for various hardware accelerators"""
gpu_info = []
accelerator_type = "none"
# NVIDIA GPUs
if torch.cuda.is_available():
accelerator_type = "cuda"
for i in range(torch.cuda.device_count()):
gpu_info.append(
{
"name": torch.cuda.get_device_name(i),
"memory": torch.cuda.get_device_properties(i).total_memory,
}
)
# AMD GPUs
elif hasattr(torch, "hip") and torch.hip.is_available():
accelerator_type = "hip"
for i in range(torch.hip.device_count()):
gpu_info.append(
{
"name": torch.hip.get_device_name(i),
"memory": (
torch.hip.get_device_properties(i).total_memory
if hasattr(torch.hip, "get_device_properties")
else None
),
}
)
# Apple Silicon
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
accelerator_type = "mps"
gpu_info.append(
{
"name": "Apple Silicon",
# NOTE: this is memory allocated to this process, not total memory
"memory": torch.mps.driver_allocated_memory(),
}
)
# Intel GPUs
elif hasattr(torch, "xpu") and torch.xpu.is_available():
accelerator_type = "xpu"
for i in range(torch.xpu.device_count()):
memory = None
if hasattr(torch.xpu, "get_device_properties"):
memory = torch.xpu.get_device_properties(i).total_memory
gpu_info.append(
{
"name": torch.xpu.get_device_name(i),
"memory": memory,
}
)
# NPUs
elif hasattr(torch, "npu") and torch.npu.is_available():
accelerator_type = "npu"
for i in range(torch.npu.device_count()):
memory = None
if hasattr(torch.npu, "get_device_properties"):
memory = torch.npu.get_device_properties(i).total_memory
gpu_info.append(
{
"name": torch.npu.get_device_name(i),
"memory": memory,
}
)
# Get relevant package versions
installed_packages = {}
for package in RELEVANT_PACKAGES:
try:
version = importlib.metadata.version(package)
installed_packages[f"{package}_version"] = version
except importlib.metadata.PackageNotFoundError:
pass
return {
"os": platform.system(),
"python_version": platform.python_version(),
"cpu_count": psutil.cpu_count(),
"memory_total": psutil.virtual_memory().total,
"accelerator_type": accelerator_type,
"accelerator_count": len(gpu_info),
"accelerator_info": gpu_info,
**installed_packages,
}
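# Example return value (illustrative; version keys depend on installed packages):
#   {"os": "Linux", "python_version": "3.11.9", "cpu_count": 32,
#    "memory_total": 134994882560, "accelerator_type": "cuda",
#    "accelerator_count": 2, "accelerator_info": [{"name": "...", "memory": ...}],
#    "torch_version": "...", "transformers_version": "...", ...}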
def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
"""Send a telemetry event"""
if not self.enabled:
return
if properties is None:
properties = {}
# Sanitize properties to remove PII
properties = self._redact_paths(properties)
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
try:
# Send event via PostHog
posthog.capture(
distinct_id=self.run_id,
event=event_type,
properties=properties,
disable_geoip=True,
)
except Exception as e: # pylint: disable=broad-exception-caught
LOG.warning(f"Failed to send telemetry event: {e}")
# Additionally, send system info telemetry when loading config.
# NOTE: Is this the best place for this?
if event_type == "config-loaded":
self.send_system_info()
def send_system_info(self):
"""Helper method for sending system info"""
if self.system_info is not None:
self.send_event(event_type="system-info", properties=self.system_info)
def shutdown(self):
"""Ensure all queued events are processed before shutdown"""
if self.enabled:
posthog.flush()
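Putting the manager together, a minimal end-to-end sketch (assuming this runs on the main process and that the opt-in variable is exported before the singleton is first constructed; the event properties are arbitrary):
import os
os.environ.setdefault("AXOLOTL_DO_NOT_TRACK", "0")  # opt in before first construction
from axolotl.telemetry.manager import TelemetryManager
manager = TelemetryManager.get_instance()
if manager.enabled:
    # Properties pass through _redact_paths before reaching PostHog; a
    # "config-loaded" event additionally triggers send_system_info()
    manager.send_event(event_type="config-loaded", properties={"micro_batch_size": 2})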

View File

@@ -0,0 +1,209 @@
"""Telemetry utilities for runtime and memory metrics."""
import logging
import time
from dataclasses import dataclass, field
from typing import Any
import psutil
import torch
from axolotl.telemetry.manager import TelemetryManager
LOG = logging.getLogger(__name__)
@dataclass
class RuntimeMetrics:
"""Container for runtime metrics to be tracked throughout training."""
# Timing metrics
start_time: float
epoch_start_times: dict[int, float] = field(init=False)
epoch_end_times: dict[int, float] = field(init=False)
# Memory metrics
peak_cpu_memory: int = 0
peak_gpu_memory: dict[int, int] = field(init=False)
# Progress metrics
total_steps: int = 0
current_epoch: int = 0
current_step: int = 0
def __post_init__(self):
"""Initialize empty metric mappings."""
self.epoch_start_times = {}
self.epoch_end_times = {}
self.peak_gpu_memory = {}
@property
def elapsed_time(self) -> float:
"""Calculate total elapsed time in seconds."""
return time.time() - self.start_time
def epoch_time(self, epoch: int) -> float | None:
"""Calculate time taken for a specific epoch in seconds."""
if epoch in self.epoch_start_times and epoch in self.epoch_end_times:
return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]
return None
def average_epoch_time(self) -> float | None:
"""Calculate average time per epoch in seconds."""
completed_epochs = [
epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times
]
if not completed_epochs:
return None
total_time = 0.0
for epoch in completed_epochs:
epoch_time = self.epoch_time(epoch)
if epoch_time is not None: # Check to avoid mypy warning
total_time += epoch_time
return total_time / len(completed_epochs)
def steps_per_second(self) -> float | None:
"""Calculate average steps per second across all training."""
if self.total_steps == 0 or self.elapsed_time == 0:
return None
return self.total_steps / self.elapsed_time
def to_dict(self) -> dict[str, Any]:
"""Convert metrics to a dictionary for telemetry reporting."""
metrics = {
"total_time_seconds": self.elapsed_time,
"total_steps": self.total_steps,
"steps_per_second": self.steps_per_second(),
"epochs_completed": len(
[
epoch
for epoch in self.epoch_start_times
if epoch in self.epoch_end_times
]
),
"peak_cpu_memory_bytes": self.peak_cpu_memory,
}
# Add per-epoch timing if available
epoch_times: dict[str, float] = {}
for epoch in sorted(self.epoch_end_times.keys()):
time_taken = self.epoch_time(epoch)
if time_taken is not None:
epoch_times[f"epoch_{epoch}_seconds"] = time_taken
if epoch_times:
metrics["epoch_times"] = epoch_times # type: ignore
metrics["average_epoch_time_seconds"] = self.average_epoch_time()
# Add GPU memory metrics if available
if self.peak_gpu_memory:
gpu_metrics: dict[str, int] = {}
for gpu_id, memory in self.peak_gpu_memory.items():
gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory
metrics["gpu_memory"] = gpu_metrics # type: ignore
return metrics
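# Example output (illustrative values):
#   {"total_time_seconds": 812.4, "total_steps": 400, "steps_per_second": 0.49,
#    "epochs_completed": 2, "peak_cpu_memory_bytes": 8589934592,
#    "epoch_times": {"epoch_0_seconds": 401.2, "epoch_1_seconds": 405.9},
#    "average_epoch_time_seconds": 403.55,
#    "gpu_memory": {"gpu_0_peak_memory_bytes": 17179869184}}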
class RuntimeMetricsTracker:
"""Tracker for runtime metrics during training."""
update_interval = 100
def __init__(self):
"""Initialize the runtime metrics tracker."""
self.metrics = RuntimeMetrics(start_time=time.time())
self.telemetry_manager = TelemetryManager.get_instance()
def start_epoch(self, epoch: int):
"""Record the start of a new epoch."""
self.metrics.current_epoch = epoch
self.metrics.epoch_start_times[epoch] = time.time()
self.update_memory_metrics()
def end_epoch(self, epoch: int):
"""Record the end of an epoch."""
self.metrics.epoch_end_times[epoch] = time.time()
def update_step(self, step: int):
"""Update the current step count."""
self.metrics.current_step = step
self.metrics.total_steps += 1
# Periodically update memory metrics
if step % self.update_interval == 0:
self.update_memory_metrics()
def _get_allocated_memory(self) -> dict[int, int]:
"""
Helper function for getting accelerator-agnostic allocated memory.
Returns:
A dictionary mapping device IDs to allocated memory in bytes
"""
memory_used: dict[int, int] = {}
# NVIDIA GPUs
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
memory_used[i] = torch.cuda.memory_allocated(i)
# AMD GPUs
elif hasattr(torch, "hip") and torch.hip.is_available():
for i in range(torch.hip.device_count()):
if hasattr(torch.hip, "memory_allocated"):
memory_used[i] = torch.hip.memory_allocated(i)
# Apple Silicon
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
# MPS doesn't have per-device memory stats since there's only one device
if hasattr(torch.mps, "current_allocated_memory"):
memory_used[0] = torch.mps.current_allocated_memory()
# Intel GPUs
elif hasattr(torch, "xpu") and torch.xpu.is_available():
for i in range(torch.xpu.device_count()):
if hasattr(torch.xpu, "memory_allocated"):
memory_used[i] = torch.xpu.memory_allocated(i)
# NPUs
elif hasattr(torch, "npu") and torch.npu.is_available():
for i in range(torch.npu.device_count()):
if hasattr(torch.npu, "memory_allocated"):
memory_used[i] = torch.npu.memory_allocated(i)
return memory_used
def update_memory_metrics(self):
"""Update peak memory usage metrics."""
# CPU memory
cpu_memory = psutil.Process().memory_info().rss
self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)
# GPU memory (if available)
memory_used = self._get_allocated_memory()
for i, memory in memory_used.items():
self.metrics.peak_gpu_memory[i] = max(
self.metrics.peak_gpu_memory.get(i, 0), memory
)
def get_memory_metrics(self) -> dict[str, Any]:
"""Get the current memory metrics as a dictionary."""
memory_metrics = {
"cpu_memory_bytes": psutil.Process().memory_info().rss,
"peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
}
# GPU memory (if available)
memory_used = self._get_allocated_memory()
for i, memory in memory_used.items():
memory_metrics[f"gpu_{i}_memory_bytes"] = memory
memory_metrics[f"gpu_{i}_peak_memory_bytes"] = (
self.metrics.peak_gpu_memory.get(i, 0)
)
return memory_metrics
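For reference, a minimal sketch of the tracker in a training loop (the step and epoch values are placeholders; telemetry is explicitly disabled here to skip the opt-in warning and its sleep):
import os
os.environ.setdefault("AXOLOTL_DO_NOT_TRACK", "1")  # skip the opt-in warning/sleep
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker
tracker = RuntimeMetricsTracker()
tracker.start_epoch(0)
for step in range(1, 201):
    tracker.update_step(step)  # refreshes memory metrics every 100 steps
tracker.end_epoch(0)
print(tracker.metrics.to_dict()["total_steps"])  # -> 200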

View File

@@ -0,0 +1,17 @@
organizations:
- "axolotl-ai-co"
- "meta-llama"
- "huggingface"
- "nvidia"
- "facebook"
- "google"
- "microsoft"
- "deepseek-ai"
- "HuggingFaceTB"
- "mistralai"
- "Qwen"
- "unsloth"
- "NousResearch"
- "allenai"
- "amd"
- "tiiuae"

View File

@@ -1,13 +1,10 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
from __future__ import annotations
import importlib
import inspect
import os
import signal
import sys
import typing
import weakref
from contextlib import ExitStack
from pathlib import Path
@@ -28,12 +25,15 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.integrations.base import PluginManager
from axolotl.loaders import (
ModelLoader,
load_processor,
load_tokenizer,
)
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -47,11 +47,11 @@ try:
except ImportError:
BetterTransformer = None
if typing.TYPE_CHECKING:
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
PLUGIN_MANAGER = PluginManager.get_instance()
def setup_model_and_tokenizer(
cfg: DictDefault,
@@ -69,7 +69,10 @@ def setup_model_and_tokenizer(
`None`), and processor (if multimodal, else `None`).
"""
# Load tokenizer
LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed
@@ -88,6 +91,14 @@ def setup_model_and_tokenizer(
if model.generation_config is not None:
model.generation_config.do_sample = True
TELEMETRY_MANAGER.send_event(
event_type="model-load", properties=model.config.to_dict()
)
if peft_config:
TELEMETRY_MANAGER.send_event(
event_type="peft-config-load", properties=peft_config.to_dict()
)
# Apply freezing if specified
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
@@ -477,7 +488,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
HFRLTrainerBuilder | HFCausalTrainerBuilder,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
@@ -522,6 +533,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
model_ref=model_ref,
peft_config=peft_config,
)
PLUGIN_MANAGER.post_trainer_create(cfg, trainer)
return (
trainer,
@@ -532,6 +544,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
)
@send_errors
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
@@ -556,8 +569,11 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
@@ -572,7 +588,6 @@ def train(
setup_model_card(cfg)
# Execute the training
resume_from_checkpoint = determine_resume_checkpoint(cfg)
execute_training(cfg, trainer, resume_from_checkpoint)
# Save the trained model and cleanup
@@ -580,7 +595,6 @@ def train(
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()
plugin_manager.post_train(cfg, model)
PLUGIN_MANAGER.post_train(cfg, model)
return model, tokenizer, trainer

View File

@@ -52,10 +52,3 @@ def patch_optimized_env():
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
set_pytorch_cuda_alloc_conf()
def get_not_null(value, default=None):
"""
return the value if it's not None, otherwise return the default value
"""
return value if value is not None else default

File diff suppressed because one or more lines are too long

View File

@@ -1,7 +1,7 @@
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
from dataclasses import dataclass
from typing import Any, List
from typing import Any
import numpy as np
from transformers import PreTrainedTokenizerBase
@@ -161,7 +161,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features: List[List[dict]] = [features]
features = [features]
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():

View File

@@ -40,7 +40,6 @@ def retry_on_request_exceptions(
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
requests.exceptions.HTTPError,
huggingface_hub.errors.HfHubHTTPError,
) as exc:
if attempt < max_retries - 1:

View File

@@ -258,7 +258,7 @@ class MultipackBatchSampler(BatchSampler):
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 16, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
@@ -443,18 +443,10 @@ class MultipackBatchSampler(BatchSampler):
if self._len_across_ranks is None:
# Sample multiple times to get stable estimate
_sampled_lens = []
for _ in range(self.num_count_samples):
self._batches = None # Reset cached batches
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
len_batches = min(_sampled_lens)
len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)]
)
# Gather minimum across all ranks
if self._len_across_ranks is None:
self._len_across_ranks = self.gather_len_batches(len_batches)
else:
self._len_across_ranks = min(
self._len_across_ranks, self.gather_len_batches(len_batches)
)
self._len_across_ranks = self.gather_len_batches(len_batches)
return self._len_across_ranks

View File

@@ -1259,7 +1259,7 @@ class AxolotlInputConfig(
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
"""Wrapper to validate GPU capabilities with the config options"""
capabilities: GPUCapabilities
env_capabilities: EnvCapabilities

View File

@@ -16,6 +16,7 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
@@ -481,9 +482,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
)
)
if cfg.dataloader_drop_last:
# drop the last batch for each epoch
total_num_steps -= int(math.ceil(cfg.num_epochs))
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -631,8 +629,6 @@ def setup_trainer(
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
if (
cfg.torch_compile
and cfg.fsdp_config

View File

@@ -1,6 +1,4 @@
"""
shared pytest fixtures
"""
"""Shared pytest fixtures"""
import functools
import importlib
@@ -559,3 +557,9 @@ def test_load_fixtures(
download_llama2_model_fixture,
):
pass
@pytest.fixture(autouse=True)
def disable_telemetry(monkeypatch):
monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1")
yield

View File

View File

@@ -0,0 +1,9 @@
"""Shared pytest fixtures for telemetry tests."""
import pytest
@pytest.fixture(autouse=True)
def disable_telemetry(monkeypatch):
monkeypatch.delenv("AXOLOTL_DO_NOT_TRACK", raising=False)
yield

View File

@@ -0,0 +1,373 @@
"""Tests for telemetry callback module."""
# pylint: disable=redefined-outer-name
import time
from unittest.mock import MagicMock, patch
import pytest
from transformers import TrainerControl, TrainerState, TrainingArguments
from axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback
def calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0):
"""Calculate expected metrics values for tests"""
time_diff = current_time - last_time
step_diff = step - last_step
return {
"steps_per_second": (
step_diff / time_diff if time_diff > 0 and step_diff > 0 else 0
),
"time_since_last_report": time_diff,
"elapsed_time": current_time - start_time,
}
@pytest.fixture
def mock_time():
"""Mock time.time() to have predictable values in tests"""
with patch("axolotl.telemetry.callbacks.time") as mock_time:
mock_time.time.return_value = 1000.0
yield mock_time
@pytest.fixture
def mock_telemetry_manager():
"""Create a mock TelemetryManager"""
with patch("axolotl.telemetry.callbacks.TelemetryManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager_class.get_instance.return_value = mock_manager
yield mock_manager
@pytest.fixture
def mock_runtime_metrics_tracker():
"""Create a mock RuntimeMetricsTracker"""
with patch(
"axolotl.telemetry.callbacks.RuntimeMetricsTracker"
) as mock_tracker_class:
mock_tracker = MagicMock()
# Set up metrics property on the tracker
mock_metrics = MagicMock()
mock_metrics.to_dict.return_value = {
"total_steps": 100,
"peak_cpu_memory_bytes": 1024,
}
mock_tracker.metrics = mock_metrics
# Make the constructor return our mock
mock_tracker_class.return_value = mock_tracker
yield mock_tracker
@pytest.fixture
def training_args():
"""Create a minimal TrainingArguments instance"""
return TrainingArguments(output_dir="./output")
@pytest.fixture
def trainer_state():
"""Create a mock TrainerState"""
state = MagicMock(spec=TrainerState)
state.global_step = 10
state.epoch = 0.5 # halfway through first epoch
state.log_history = [{"loss": 2.5, "learning_rate": 5e-5}]
return state
@pytest.fixture
def trainer_control():
"""Create a mock TrainerControl"""
return MagicMock(spec=TrainerControl)
# pylint: disable=unused-argument
@pytest.fixture
def callback(mock_telemetry_manager, mock_runtime_metrics_tracker):
"""Create a TelemetryCallback instance with mocked dependencies"""
return TelemetryCallback()
class TestTelemetryCallback:
"""Tests for the TelemetryCallback class."""
def test_initialization(self, callback, mock_runtime_metrics_tracker):
"""Test callback initialization."""
assert callback.current_epoch == -1
assert callback.tracker == mock_runtime_metrics_tracker
assert callback.last_report_step == 0
assert hasattr(callback, "start_time")
assert hasattr(callback, "last_report_time")
assert callback.report_interval_steps == 100
def test_on_train_begin(
self,
callback,
mock_telemetry_manager,
training_args,
trainer_state,
trainer_control,
):
"""Test on_train_begin sends expected event."""
callback.on_train_begin(training_args, trainer_state, trainer_control)
mock_telemetry_manager.send_event.assert_called_once_with(
event_type="train-start"
)
def test_on_train_end(
self,
callback,
mock_telemetry_manager,
training_args,
trainer_state,
trainer_control,
):
"""Test on_train_end sends expected event with metrics."""
callback.on_train_end(training_args, trainer_state, trainer_control)
mock_telemetry_manager.send_event.assert_called_once()
call_args = mock_telemetry_manager.send_event.call_args[1]
assert call_args["event_type"] == "train-end"
assert "loss" in call_args["properties"]
assert call_args["properties"]["loss"] == 2.5
assert "learning_rate" in call_args["properties"]
assert call_args["properties"]["learning_rate"] == 5e-5
# Check that metrics from RuntimeMetricsTracker are included
assert "total_steps" in call_args["properties"]
assert call_args["properties"]["total_steps"] == 100
assert "peak_cpu_memory_bytes" in call_args["properties"]
assert call_args["properties"]["peak_cpu_memory_bytes"] == 1024
def test_on_epoch_begin(
self,
callback,
mock_runtime_metrics_tracker,
training_args,
trainer_state,
trainer_control,
):
"""Test on_epoch_begin updates epoch counter and calls tracker."""
initial_epoch = callback.current_epoch
callback.on_epoch_begin(training_args, trainer_state, trainer_control)
assert callback.current_epoch == initial_epoch + 1
mock_runtime_metrics_tracker.start_epoch.assert_called_once_with(
initial_epoch + 1
)
def test_on_epoch_end(
self,
callback,
mock_runtime_metrics_tracker,
training_args,
trainer_state,
trainer_control,
):
"""Test on_epoch_end calls tracker."""
# Set current epoch
callback.current_epoch = 2
callback.on_epoch_end(training_args, trainer_state, trainer_control)
mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2)
def test_on_step_end_no_report(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end updates tracker but doesn't report if criteria not met."""
# Set up state to avoid reporting
trainer_state.global_step = 42 # Not divisible by report_interval_steps
callback.last_report_step = 41 # Just 1 step since last report
callback.last_report_time = time.time() # Just now
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should update tracker
mock_runtime_metrics_tracker.update_step.assert_called_once_with(42)
# Should not send telemetry
mock_telemetry_manager.send_event.assert_not_called()
# Should not update last report time/step
assert callback.last_report_step == 41
def test_on_step_end_report_interval_steps(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker,
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end reports when step interval is reached."""
# Set up state with clear values
current_step = 100 # Exactly matches report_interval_steps
last_step = 0
start_time = 900.0
current_time = 1000.0
time_diff = current_time - start_time # 100 seconds
# Configure state and callback
trainer_state.global_step = current_step
callback.report_interval_steps = 100
callback.last_report_step = last_step
callback.start_time = start_time
callback.last_report_time = start_time
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should update tracker
mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step)
mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once()
# Should send telemetry
mock_telemetry_manager.send_event.assert_called_once()
call_args = mock_telemetry_manager.send_event.call_args[1]
assert call_args["event_type"] == "train-progress"
# Properties should include expected values
props = call_args["properties"]
assert props["step"] == current_step
assert props["elapsed_time"] == time_diff # 1000 - 900 = 100
assert props["time_since_last_report"] == time_diff # 1000 - 900 = 100
assert props["steps_per_second"] == 1.0 # 100 steps / 100 seconds
# Should update last report time/step
assert callback.last_report_step == current_step
assert callback.last_report_time == current_time
def test_on_step_end_report_time_elapsed(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end reports when enough time has elapsed."""
# Set up state with clear values
current_step = 120
last_step = 10
start_time = 900.0
current_time = 1000.0
time_diff = TIME_SINCE_LAST + 1 # Just over the threshold
# Configure state and callback
trainer_state.global_step = current_step
callback.report_interval_steps = 100
callback.last_report_step = last_step
callback.start_time = start_time
callback.last_report_time = current_time - time_diff
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should send telemetry
mock_telemetry_manager.send_event.assert_called_once()
# Properties should include expected values
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
expected_metrics = calc_expected_metrics(
current_step, last_step, current_time, current_time - time_diff, start_time
)
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
assert (
props["time_since_last_report"]
== expected_metrics["time_since_last_report"]
)
def test_on_step_end_first_step(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end always reports on first step."""
# Set up state with clear values
current_step = 1 # First step
last_step = 0
start_time = 900.0
current_time = 1000.0
last_report_time = 999.0 # Just 1 second ago
# Configure state and callback
trainer_state.global_step = current_step
callback.report_interval_steps = 100
callback.last_report_step = last_step
callback.start_time = start_time
callback.last_report_time = last_report_time
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should send telemetry even though not much time has passed
mock_telemetry_manager.send_event.assert_called_once()
# Properties should include expected values for first step
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
assert props["step"] == current_step
expected_metrics = calc_expected_metrics(
current_step, last_step, current_time, last_report_time, start_time
)
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
def test_log_history_empty(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test handling of empty log history."""
# Set up state with clear values
current_step = 1
start_time = 900.0
current_time = 1000.0
# Configure state and callback
trainer_state.global_step = current_step
trainer_state.log_history = []
callback.start_time = start_time
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should still send telemetry
mock_telemetry_manager.send_event.assert_called_once()
# Properties should have default values for missing log data
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
assert props["loss"] == 0
assert props["learning_rate"] == 0

View File

@@ -0,0 +1,341 @@
"""Tests for telemetry error utilities"""
# pylint: disable=redefined-outer-name
from unittest.mock import MagicMock, patch
import pytest
from axolotl.telemetry.errors import sanitize_stack_trace, send_errors
@pytest.fixture(autouse=True)
def reset_error_flag(monkeypatch):
"""Reset ERROR_HANDLED flag using monkeypatch"""
import axolotl.telemetry.errors
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
yield
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
@pytest.fixture
def example_stack_trace():
"""Provide a sample stack trace with mixed paths"""
return """Traceback (most recent call last):
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
trainer = get_trainer(cfg)
File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer
model = get_model(cfg, tokenizer)
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model
raise ValueError("Model path not found")
ValueError: Model path not found
"""
@pytest.fixture
def windows_stack_trace():
"""Provide a sample stack trace with Windows paths"""
return """Traceback (most recent call last):
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main
trainer = get_trainer(cfg)
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer
model = get_model(cfg, tokenizer)
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained
raise ValueError(f"Unrecognized configuration class {config.__class__}")
ValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'>
"""
@pytest.fixture
def mixed_stack_trace():
"""Provide a sample stack trace with both axolotl and non-axolotl paths"""
return """Traceback (most recent call last):
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
trainer = get_trainer(cfg)
File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train
self._inner_training_loop()
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop
super()._inner_training_loop()
File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
RuntimeError: CUDA out of memory
"""
@pytest.fixture
def venv_stack_trace():
"""Provide a sample stack trace with virtual environment paths"""
return """Traceback (most recent call last):
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train
self._inner_training_loop()
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop
self.accelerator.backward(loss)
File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward
self.scaler.scale(loss).backward(**kwargs)
File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
RuntimeError: CUDA out of memory
"""
@pytest.fixture
def dist_packages_stack_trace():
"""Provide a sample stack trace with dist-packages paths"""
return """Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data
data = self._dataset_fetcher.fetch(index)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__
raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.")
IndexError: Index 10000 out of range for dataset of length 9832.
"""
@pytest.fixture
def project_stack_trace():
"""Provide a sample stack trace from a project directory (not a virtual env)"""
return """Traceback (most recent call last):
File "/home/user/projects/myproject/run.py", line 25, in <module>
main()
File "/home/user/projects/myproject/src/cli.py", line 45, in main
app.run()
File "/home/user/projects/myproject/src/app.py", line 102, in run
raise ValueError("Configuration missing")
ValueError: Configuration missing
"""
def test_sanitize_stack_trace(example_stack_trace):
"""Test that sanitize_stack_trace properly preserves axolotl paths"""
sanitized = sanitize_stack_trace(example_stack_trace)
# Check that personal paths are removed
assert "/home/user" not in sanitized
assert ".local/lib/python3.9" not in sanitized
# Check that site-packages is preserved
assert "site-packages/axolotl/cli/train.py" in sanitized
assert "site-packages/axolotl/train.py" in sanitized
assert "site-packages/axolotl/utils/models.py" in sanitized
# Check that error message is preserved
assert "ValueError: Model path not found" in sanitized
def test_sanitize_windows_paths(windows_stack_trace):
"""Test that sanitize_stack_trace handles Windows paths"""
sanitized = sanitize_stack_trace(windows_stack_trace)
# Check that personal paths are removed
assert "C:\\Users\\name" not in sanitized
assert "AppData\\Local\\Programs\\Python" not in sanitized
# Check that both axolotl and transformers packages are preserved
assert (
"site-packages\\axolotl\\cli\\train.py" in sanitized
or "site-packages/axolotl/cli/train.py" in sanitized
)
assert (
"site-packages\\axolotl\\train.py" in sanitized
or "site-packages/axolotl/train.py" in sanitized
)
assert (
"site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized
or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized
)
# Check that error message is preserved
assert "ValueError: Unrecognized configuration class" in sanitized
def test_sanitize_mixed_paths(mixed_stack_trace):
"""Test that sanitize_stack_trace preserves all package paths"""
sanitized = sanitize_stack_trace(mixed_stack_trace)
# Check that all package paths are preserved
assert "site-packages/axolotl/cli/train.py" in sanitized
assert "site-packages/transformers/trainer.py" in sanitized
assert "site-packages/axolotl/utils/trainer.py" in sanitized
assert "site-packages/torch/utils/data/dataloader.py" in sanitized
# Check that error message is preserved
assert "RuntimeError: CUDA out of memory" in sanitized
def test_sanitize_venv_paths(venv_stack_trace):
"""Test that sanitize_stack_trace preserves virtual environment package paths"""
sanitized = sanitize_stack_trace(venv_stack_trace)
# Check that personal paths are removed
assert "/home/user/venv" not in sanitized
# Check that all package paths are preserved
assert "site-packages/transformers/trainer.py" in sanitized
assert "site-packages/accelerate/accelerator.py" in sanitized
assert "site-packages/torch/_tensor.py" in sanitized
# Check that error message is preserved
assert "RuntimeError: CUDA out of memory" in sanitized
def test_sanitize_dist_packages(dist_packages_stack_trace):
"""Test that sanitize_stack_trace preserves dist-packages paths"""
sanitized = sanitize_stack_trace(dist_packages_stack_trace)
# Check that system paths are removed
assert "/usr/local/lib/python3.8" not in sanitized
# Check that all package paths are preserved
assert "dist-packages/torch/utils/data/dataloader.py" in sanitized
assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized
assert "dist-packages/datasets/arrow_dataset.py" in sanitized
# Check that error message is preserved
assert (
"IndexError: Index 10000 out of range for dataset of length 9832." in sanitized
)
def test_sanitize_project_paths(project_stack_trace):
"""Test handling of project paths (non-virtual env)"""
sanitized = sanitize_stack_trace(project_stack_trace)
# Check that personal paths are removed
assert "/home/user/projects" not in sanitized
# For non-package paths, we should at least preserve the filename
assert "run.py" in sanitized
assert "cli.py" in sanitized
assert "app.py" in sanitized
# Check that error message is preserved
assert "ValueError: Configuration missing" in sanitized
@pytest.fixture
def mock_telemetry_manager():
"""Create a mock TelemetryManager"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager.enabled = True
mock_manager_class.get_instance.return_value = mock_manager
yield mock_manager
def test_send_errors_successful_execution(mock_telemetry_manager):
"""Test that send_errors doesn't send telemetry for successful function execution"""
@send_errors
def test_func():
return "success"
result = test_func()
assert result == "success"
mock_telemetry_manager.send_event.assert_not_called()
def test_send_errors_with_exception(mock_telemetry_manager):
"""Test that send_errors sends telemetry when an exception occurs"""
test_error = ValueError("Test error")
@send_errors
def test_func():
raise test_error
with pytest.raises(ValueError) as excinfo:
test_func()
assert excinfo.value == test_error
mock_telemetry_manager.send_event.assert_called_once()
# Check that the error info was passed correctly
call_args = mock_telemetry_manager.send_event.call_args[1]
assert "test_func-error" in call_args["event_type"]
assert "Test error" in call_args["properties"]["exception"]
assert "stack_trace" in call_args["properties"]
def test_send_errors_nested_calls(mock_telemetry_manager):
"""Test that send_errors only sends telemetry once for nested decorated functions"""
@send_errors
def inner_func():
raise ValueError("Inner error")
@send_errors
def outer_func():
return inner_func()
with pytest.raises(ValueError):
outer_func()
# Telemetry should be sent only once for the inner function
assert mock_telemetry_manager.send_event.call_count == 1
call_args = mock_telemetry_manager.send_event.call_args[1]
assert "inner_func-error" in call_args["event_type"]
def test_send_errors_telemetry_disable():
"""Test that send_errors doesn't attempt to send telemetry when disabled"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager.enabled = False
mock_manager_class.get_instance.return_value = mock_manager
@send_errors
def test_func():
raise ValueError("Test error")
with pytest.raises(ValueError):
test_func()
mock_manager.send_event.assert_not_called()
def test_error_handled_reset():
"""Test that ERROR_HANDLED flag is properly reset"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
# Create and configure the mock manager
mock_manager = MagicMock()
mock_manager.enabled = True
mock_manager_class.get_instance.return_value = mock_manager
from axolotl.telemetry.errors import ERROR_HANDLED
@send_errors
def test_func():
raise ValueError("Test error")
assert not ERROR_HANDLED
with pytest.raises(ValueError):
test_func()
from axolotl.telemetry.errors import ERROR_HANDLED
assert ERROR_HANDLED
def test_module_path_resolution(mock_telemetry_manager):
"""Test that the module path is correctly resolved for the event type"""
import inspect
current_module = inspect.getmodule(test_module_path_resolution).__name__
@send_errors
def test_func():
raise ValueError("Test error")
with pytest.raises(ValueError):
test_func()
assert mock_telemetry_manager.send_event.called
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
expected_event_type = f"{current_module}.test_func-error"
assert expected_event_type == event_type

View File

@@ -0,0 +1,278 @@
"""Tests for TelemetryManager class and utilities"""
# pylint: disable=redefined-outer-name,protected-access
import os
from unittest.mock import patch
import pytest
import yaml
from axolotl.telemetry.manager import TelemetryManager
@pytest.fixture
def mock_whitelist(tmp_path):
"""Create a temporary whitelist file for testing"""
whitelist_content = {
"organizations": ["meta-llama", "mistralai"],
}
whitelist_file = tmp_path / "whitelist.yaml"
with open(whitelist_file, "w", encoding="utf-8") as f:
yaml.dump(whitelist_content, f)
return str(whitelist_file)
@pytest.fixture
def telemetry_manager_class():
"""Reset the TelemetryManager singleton between tests"""
original_instance = TelemetryManager._instance
original_initialized = TelemetryManager._initialized
TelemetryManager._instance = None
TelemetryManager._initialized = False
yield TelemetryManager
TelemetryManager._instance = original_instance
TelemetryManager._initialized = original_initialized
@pytest.fixture
def manager(telemetry_manager_class, mock_whitelist):
"""Create a TelemetryManager instance with mocked dependencies"""
with (
patch("posthog.capture"),
patch("posthog.flush"),
patch("time.sleep"),
patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
patch.dict(os.environ, {"RANK": "0"}),
):
manager = telemetry_manager_class()
# Manually enable for most tests
manager.enabled = True
return manager
def test_singleton_instance(telemetry_manager_class):
"""Test that TelemetryManager is a singleton"""
with (
patch("posthog.capture"),
patch("time.sleep"),
patch.dict(os.environ, {"RANK": "0"}),
):
first = telemetry_manager_class()
second = telemetry_manager_class()
assert first is second
assert telemetry_manager_class.get_instance() is first
def test_telemetry_disabled_by_default(telemetry_manager_class):
"""Test that telemetry is disabled by default (opt-in)"""
with (
patch.dict(os.environ, {"RANK": "0"}, clear=True),
patch("time.sleep"),
patch("logging.Logger.info"),
):
manager = telemetry_manager_class()
assert not manager.enabled
def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):
"""Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0"""
with (
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}),
patch("time.sleep"),
):
manager = telemetry_manager_class()
assert manager.enabled
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
with (
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}),
patch("time.sleep"),
):
manager = telemetry_manager_class()
assert not manager.enabled
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
with (
patch.dict(
os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"}
),
patch("time.sleep"),
):
manager = telemetry_manager_class()
assert not manager.enabled
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
"""Test that telemetry is disabled for non-main processes"""
with (
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}),
patch("time.sleep"),
):
manager = telemetry_manager_class()
assert not manager.enabled
def test_opt_in_info_displayed(telemetry_manager_class):
"""Test that opt-in info is displayed when telemetry is not configured"""
with (
patch.dict(os.environ, {"RANK": "0"}, clear=True),
patch("logging.Logger.warning") as mock_warning,
patch("time.sleep"),
):
telemetry_manager_class()
info_displayed = False
for call in mock_warning.call_args_list:
print(f"call: {call}")
if "Telemetry is currently disabled by default" in str(call):
info_displayed = True
break
assert info_displayed
def test_is_whitelisted(telemetry_manager_class, mock_whitelist):
"""Test org whitelist functionality"""
with (
patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
):
manager = telemetry_manager_class()
# Should match organizations from the mock whitelist
assert manager._is_whitelisted("meta-llama/llama-7b")
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
# Should not match
assert not manager._is_whitelisted("unknown/model")
# Should handle case insensitively
assert manager._is_whitelisted("META-LLAMA/Llama-7B")
# Should handle empty input
assert not manager._is_whitelisted("")
def test_system_info_collection(manager):
"""Test system information collection"""
system_info = manager._get_system_info()
# Check essential keys
assert "os" in system_info
assert "python_version" in system_info
assert "cpu_count" in system_info
assert "memory_total" in system_info
assert "accelerator_count" in system_info
def test_send_event(telemetry_manager_class):
"""Test basic event sending"""
with (
patch("posthog.capture") as mock_capture,
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
):
manager = telemetry_manager_class()
# Test with clean properties (no PII)
manager.send_event("test_event", {"key": "value"})
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "test_event"
assert mock_capture.call_args[1]["properties"] == {"key": "value"}
assert mock_capture.call_args[1]["distinct_id"] == manager.run_id
# Test with default properties (None)
mock_capture.reset_mock()
manager.send_event("simple_event")
assert mock_capture.called
assert mock_capture.call_args[1]["properties"] == {}
def test_send_system_info(telemetry_manager_class):
"""Test sending system info"""
with (
patch("posthog.capture") as mock_capture,
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
):
manager = telemetry_manager_class()
manager.send_system_info()
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "system-info"
assert mock_capture.call_args[1]["properties"] == manager.system_info
def test_redacted_properties(telemetry_manager_class):
"""Test path redaction in send_event method"""
with (
patch("posthog.capture") as mock_capture,
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
):
manager = telemetry_manager_class()
# Test with properties containing various paths and non-paths
test_properties = {
"filepath": "/home/user/sensitive/data.txt",
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
"output_dir": "/var/lib/data",
"path_to_model": "models/llama/7b",
"message": "Training started", # Should not be redacted
"metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted
"base_model": "models/local_model",
"nested": {
"model_path": "/models/my_model",
"root_dir": "/home/user/projects",
"stats": {"steps": 1000, "epochs": 3}, # Should not be redacted
},
}
manager.send_event("test_event", test_properties)
# Verify the call was made
assert mock_capture.called
# Get the sanitized properties that were sent
sanitized = mock_capture.call_args[1]["properties"]
# Check that path-like and base_model keys were redacted
assert sanitized["filepath"] == "[REDACTED]"
assert sanitized["windows_path"] == "[REDACTED]"
assert sanitized["path_to_model"] == "[REDACTED]"
assert sanitized["base_model"] == "[REDACTED]"
# Check that non-path values were preserved
assert sanitized["message"] == "Training started"
assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95}
# Check nested structure handling
assert sanitized["nested"]["model_path"] == "[REDACTED]"
assert sanitized["nested"]["root_dir"] == "[REDACTED]"
assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3}


def test_disable_telemetry(manager):
"""Test that disabled telemetry doesn't send events"""
with patch("posthog.capture") as mock_capture:
manager.enabled = False
manager.send_event("test_event")
assert not mock_capture.called


def test_exception_handling_during_send(manager):
"""Test that exceptions in PostHog are handled gracefully"""
with (
patch("posthog.capture", side_effect=Exception("Test error")),
patch("logging.Logger.warning") as mock_warning,
):
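        # The call below must not raise even though posthog.capture throws; the
        # manager is expected to log a warning and carry on.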
manager.send_event("test_event")
warning_logged = False
for call in mock_warning.call_args_list:
if "Failed to send telemetry event" in str(call):
warning_logged = True
break
assert warning_logged


def test_shutdown(manager):
"""Test shutdown behavior"""
with patch("posthog.flush") as mock_flush:
manager.shutdown()
assert mock_flush.called
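
# Taken together, a hedged sketch of the manager lifecycle these tests exercise
# (the "train-start" event name is illustrative only):
#
#     manager = TelemetryManager()        # tests construct directly; callers
#                                         # use TelemetryManager.get_instance()
#     manager.send_system_info()          # emits the "system-info" event
#     manager.send_event("train-start")   # arbitrary event, optional properties
#     manager.shutdown()                  # flushes queued PostHog events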


@@ -0,0 +1,357 @@
"""Tests for runtime metrics telemetry module"""
# pylint: disable=redefined-outer-name
from unittest.mock import MagicMock, patch
import pytest
from axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker
@pytest.fixture
def mock_time():
"""Mock time.time() to have predictable values in tests"""
with patch("time.time") as mock_time:
# Start with time 1000.0 and increment by 10 seconds on each call
times = [1000.0 + i * 10 for i in range(10)]
mock_time.side_effect = times
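        # Tests that need an exact timestamp override side_effect/return_value
        # on this mock rather than relying on the advancing sequence.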
yield mock_time


@pytest.fixture
def mock_telemetry_manager():
"""Create a mock TelemetryManager"""
with patch(
"axolotl.telemetry.runtime_metrics.TelemetryManager"
) as mock_manager_class:
mock_manager = MagicMock()
mock_manager.enabled = True
mock_manager_class.get_instance.return_value = mock_manager
yield mock_manager


@pytest.fixture
def mock_psutil():
"""Mock psutil for memory information"""
with patch("axolotl.telemetry.runtime_metrics.psutil") as mock_psutil:
mock_process = MagicMock()
mock_memory_info = MagicMock()
# Set initial memory to 1GB
mock_memory_info.rss = 1024 * 1024 * 1024
mock_process.memory_info.return_value = mock_memory_info
mock_psutil.Process.return_value = mock_process
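        # The tracker presumably reads Process().memory_info().rss (resident
        # set size, in bytes) for CPU memory; 1GB is the baseline used below.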
yield mock_psutil


@pytest.fixture
def mock_torch():
"""Mock torch.cuda functions"""
with patch("axolotl.telemetry.runtime_metrics.torch") as mock_torch:
mock_torch.cuda.is_available.return_value = True
mock_torch.cuda.device_count.return_value = 2
# Mock memory allocated per device (1GB for device 0, 2GB for device 1)
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 1) * 1024 * 1024 * 1024
)
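        # torch.cuda.memory_allocated(device) reports bytes currently allocated
        # by tensors on that device, which is what the tracker samples.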
yield mock_torch


class TestRuntimeMetrics:
"""Tests for RuntimeMetrics class."""

    def test_initialization(self):
"""Test RuntimeMetrics initialization."""
metrics = RuntimeMetrics(start_time=1000.0)
assert metrics.start_time == 1000.0
assert metrics.epoch_start_times == {}
assert metrics.epoch_end_times == {}
assert metrics.peak_gpu_memory == {}
assert metrics.total_steps == 0
assert metrics.current_epoch == 0
assert metrics.current_step == 0
assert metrics.peak_cpu_memory == 0

    def test_elapsed_time(self, mock_time):
"""Test elapsed_time property."""
metrics = RuntimeMetrics(start_time=1000.0)
# Mock time.time() to return 1050.0
mock_time.side_effect = [1050.0]
assert metrics.elapsed_time == 50.0

    def test_epoch_time(self):
"""Test epoch_time method."""
metrics = RuntimeMetrics(start_time=1000.0)
# No epoch data
assert metrics.epoch_time(0) is None
# Add epoch start but no end
metrics.epoch_start_times[0] = 1000.0
assert metrics.epoch_time(0) is None
# Add epoch end
metrics.epoch_end_times[0] = 1060.0
assert metrics.epoch_time(0) == 60.0

    def test_average_epoch_time(self):
"""Test average_epoch_time method."""
metrics = RuntimeMetrics(start_time=1000.0)
# No completed epochs
assert metrics.average_epoch_time() is None
# Add one completed epoch
metrics.epoch_start_times[0] = 1000.0
metrics.epoch_end_times[0] = 1060.0
assert metrics.average_epoch_time() == 60.0
# Add second completed epoch
metrics.epoch_start_times[1] = 1060.0
metrics.epoch_end_times[1] = 1140.0 # 80 seconds
assert metrics.average_epoch_time() == 70.0 # Average of 60 and 80
# Add incomplete epoch (should not affect average)
metrics.epoch_start_times[2] = 1140.0
assert metrics.average_epoch_time() == 70.0

    def test_steps_per_second(self, mock_time):
"""Test steps_per_second method."""
metrics = RuntimeMetrics(start_time=1000.0)
        # No steps recorded yet; pin time.time() at a fixed value
mock_time.side_effect = None
mock_time.return_value = 1050.0
assert metrics.steps_per_second() is None
        # Now add steps; time stays pinned, so the rate is deterministic
metrics.total_steps = 100
mock_time.return_value = 1050.0 # Keep same time for consistent result
assert metrics.steps_per_second() == 2.0 # 100 steps / 50 seconds

    def test_to_dict_basic(self, mock_time):
"""Test to_dict method with basic metrics."""
metrics = RuntimeMetrics(start_time=1000.0)
metrics.total_steps = 100
metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024 # 2GB
# Mock elapsed_time
mock_time.side_effect = None
mock_time.return_value = 1050.0
result = metrics.to_dict()
assert result["total_time_seconds"] == 50.0
assert result["total_steps"] == 100
assert result["steps_per_second"] == 2.0
assert result["epochs_completed"] == 0
assert result["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
assert "epoch_times" not in result
assert "gpu_memory" not in result

    def test_to_dict_with_epochs(self, mock_time):
"""Test to_dict method with epoch data."""
metrics = RuntimeMetrics(start_time=1000.0)
metrics.total_steps = 100
# Add epoch data
metrics.epoch_start_times[0] = 1000.0
metrics.epoch_end_times[0] = 1060.0
metrics.epoch_start_times[1] = 1060.0
metrics.epoch_end_times[1] = 1140.0
# Mock elapsed_time
mock_time.side_effect = None
mock_time.return_value = 1150.0
result = metrics.to_dict()
assert "epoch_times" in result
assert result["epoch_times"]["epoch_0_seconds"] == 60.0
assert result["epoch_times"]["epoch_1_seconds"] == 80.0
assert result["average_epoch_time_seconds"] == 70.0

    def test_to_dict_with_gpu_memory(self, mock_time):
"""Test to_dict method with GPU memory data."""
metrics = RuntimeMetrics(start_time=1000.0)
metrics.peak_gpu_memory = {
0: 1 * 1024 * 1024 * 1024, # 1GB
1: 2 * 1024 * 1024 * 1024, # 2GB
}
# Mock elapsed_time
mock_time.side_effect = [1050.0]
result = metrics.to_dict()
assert "gpu_memory" in result
assert result["gpu_memory"]["gpu_0_peak_memory_bytes"] == 1 * 1024 * 1024 * 1024
assert result["gpu_memory"]["gpu_1_peak_memory_bytes"] == 2 * 1024 * 1024 * 1024


class TestRuntimeMetricsTracker:
"""Tests for RuntimeMetricsTracker class."""

    # pylint: disable=unused-argument
def test_initialization(self, mock_time, mock_telemetry_manager):
"""Test RuntimeMetricsTracker initialization."""
tracker = RuntimeMetricsTracker()
assert isinstance(tracker.metrics, RuntimeMetrics)
assert tracker.metrics.start_time == 1000.0 # First value from mock_time

    # pylint: disable=unused-argument
def test_start_epoch(
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
):
"""Test start_epoch method."""
tracker = RuntimeMetricsTracker()
# Reset mock_time to control next value
mock_time.side_effect = [1010.0]
tracker.start_epoch(0)
assert tracker.metrics.current_epoch == 0
assert tracker.metrics.epoch_start_times[0] == 1010.0
# Verify memory metrics were updated
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
assert 0 in tracker.metrics.peak_gpu_memory
assert 1 in tracker.metrics.peak_gpu_memory

    # pylint: disable=unused-argument
def test_end_epoch(self, mock_time, mock_telemetry_manager):
"""Test end_epoch method."""
tracker = RuntimeMetricsTracker()
# Start epoch 0
mock_time.side_effect = [1010.0]
tracker.start_epoch(0)
# End epoch 0
mock_time.side_effect = [1060.0]
tracker.end_epoch(0)
assert 0 in tracker.metrics.epoch_end_times
assert tracker.metrics.epoch_end_times[0] == 1060.0

    # pylint: disable=unused-argument
def test_update_step(
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
):
"""Test update_step method."""
tracker = RuntimeMetricsTracker()
# Update step to a non-multiple of 100
tracker.update_step(42)
assert tracker.metrics.current_step == 42
assert tracker.metrics.total_steps == 1
# Memory metrics should not be updated for non-multiple of 100
assert tracker.metrics.peak_cpu_memory == 0
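        # Sampling memory only on multiples of 100 presumably keeps psutil and
        # CUDA queries cheap relative to the training loop.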
# Update step to a multiple of 100
tracker.update_step(100)
assert tracker.metrics.current_step == 100
assert tracker.metrics.total_steps == 2
# Memory metrics should be updated for multiple of 100
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024

    # pylint: disable=unused-argument
def test_update_memory_metrics(
self, mock_psutil, mock_torch, mock_telemetry_manager
):
"""Test update_memory_metrics method."""
tracker = RuntimeMetricsTracker()
# Initial memory state
assert tracker.metrics.peak_cpu_memory == 0
assert tracker.metrics.peak_gpu_memory == {}
# Update memory metrics
tracker.update_memory_metrics()
# Verify CPU memory
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
# Verify GPU memory
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
# Change mocked memory values to be lower
mock_process = mock_psutil.Process.return_value
mock_memory_info = mock_process.memory_info.return_value
mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024 # 0.5GB
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 0.5) * 1024 * 1024 * 1024
)
# Update memory metrics again
tracker.update_memory_metrics()
# Peak values should not decrease
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
# Change mocked memory values to be higher
mock_memory_info.rss = 2 * 1024 * 1024 * 1024 # 2GB
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 2) * 1024 * 1024 * 1024
)
# Update memory metrics again
tracker.update_memory_metrics()
# Peak values should increase
assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024

    # pylint: disable=unused-argument
def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager):
"""Test get_memory_metrics method."""
tracker = RuntimeMetricsTracker()
# Set peak memory values
tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024
tracker.metrics.peak_gpu_memory = {
0: 3 * 1024 * 1024 * 1024,
1: 4 * 1024 * 1024 * 1024,
}
# Get memory metrics
memory_metrics = tracker.get_memory_metrics()
# Verify CPU memory
assert (
memory_metrics["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024
) # Current value from mock
assert (
memory_metrics["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
) # Peak value we set
# Verify GPU memory
assert (
memory_metrics["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024
) # Current value from mock
assert (
memory_metrics["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024
) # Peak value we set
assert (
memory_metrics["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024
) # Current value from mock
assert (
memory_metrics["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024
) # Peak value we set