Compare commits
33 Commits
sdpa-cp
...
telemetry-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
345a159796 | ||
|
|
657bffd85f | ||
|
|
f0dde8e2d5 | ||
|
|
25fa4df70f | ||
|
|
e735f4270b | ||
|
|
035e7a2f4c | ||
|
|
2d36c11264 | ||
|
|
b8ec5bdccf | ||
|
|
249405b46e | ||
|
|
d3be84fec2 | ||
|
|
1c74ab175f | ||
|
|
b2f1fc109a | ||
|
|
5a2a80cc48 | ||
|
|
4033fe74f8 | ||
|
|
e9df4444be | ||
|
|
ffd2985750 | ||
|
|
17310f9acc | ||
|
|
71ae6f9f87 | ||
|
|
9dd1092f8f | ||
|
|
2c2f2647a9 | ||
|
|
98313a6b3f | ||
|
|
8b75205d3b | ||
|
|
ef4990f304 | ||
|
|
db3297b090 | ||
|
|
86ed554bda | ||
|
|
f254d7d5a2 | ||
|
|
d8b0522ea0 | ||
|
|
1edd6b9524 | ||
|
|
66c6fb56cb | ||
|
|
90b39ce112 | ||
|
|
5afab46cc6 | ||
|
|
bd152c6115 | ||
|
|
76336743ff |
2
.github/workflows/multi-gpu-e2e.yml
vendored
2
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
- 'setup.py'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/multi-gpu-e2e.yml'
|
||||
- 'src/axolotl/core/trainers/mixins/context_parallel.py'
|
||||
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
|
||||
- 'src/axolotl/utils/distributed.py'
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
|
||||
@@ -112,6 +112,13 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
|
||||
|
||||
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
|
||||
|
||||
## 📈 Telemetry
|
||||
|
||||
Axolotl has opt-in telemetry that helps us understand how the project is being used
|
||||
and prioritize improvements. We collect basic system information, model types, and
|
||||
error rates—never personal data or file paths. Telemetry is disabled by default. To
|
||||
enable it, set AXOLOTL_DO_NOT_TRACK=0. For more details, see our [telemetry documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html).
|
||||
|
||||
## Supported Models
|
||||
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||
|
||||
@@ -75,7 +75,7 @@ quartodoc:
|
||||
- title: Context Managers
|
||||
desc: Context managers for altering trainer behaviors
|
||||
contents:
|
||||
- utils.ctx_managers.context_parallel
|
||||
- utils.ctx_managers.sequence_parallel
|
||||
- title: Prompt Strategies
|
||||
desc: Prompt formatting strategies
|
||||
contents:
|
||||
@@ -236,6 +236,7 @@ website:
|
||||
- docs/inference.qmd
|
||||
- docs/cli.qmd
|
||||
- docs/config.qmd
|
||||
- docs/telemetry.qmd
|
||||
- text: "API Reference"
|
||||
href: docs/api
|
||||
|
||||
@@ -274,7 +275,7 @@ website:
|
||||
- docs/unsloth.qmd
|
||||
- docs/torchao.qmd
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/context_parallelism.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
|
||||
@@ -764,13 +764,13 @@ ddp_timeout:
|
||||
ddp_bucket_cap_mb:
|
||||
ddp_broadcast_buffers:
|
||||
|
||||
# Context parallelism
|
||||
# Sequence parallelism
|
||||
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
||||
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
||||
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||
# See https://docs.axolotl.ai/docs/context_parallelism.html for more details.
|
||||
context_parallel_degree:
|
||||
# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.
|
||||
sequence_parallel_degree:
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
# Must evenly divide the number of KV heads in your model.
|
||||
heads_k_stride: 1
|
||||
|
||||
@@ -18,7 +18,7 @@ Axolotl supports several methods for multi-GPU training:
|
||||
|
||||
- DeepSpeed (recommended)
|
||||
- FSDP (Fully Sharded Data Parallel)
|
||||
- Context parallelism
|
||||
- Sequence parallelism
|
||||
- FSDP + QLoRA
|
||||
|
||||
## DeepSpeed {#sec-deepspeed}
|
||||
@@ -80,14 +80,14 @@ fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
## Context parallelism {#sec-sequence-parallelism}
|
||||
## Sequence parallelism {#sec-sequence-parallelism}
|
||||
|
||||
We support context parallelism (SP) via the
|
||||
We support sequence parallelism (SP) via the
|
||||
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
|
||||
allows one to split up sequences across GPUs, which is useful in the event that a
|
||||
single sequence causes OOM errors during model training.
|
||||
|
||||
See our [dedicated guide](context_parallelism.qmd) for more information.
|
||||
See our [dedicated guide](sequence_parallelism.qmd) for more information.
|
||||
|
||||
### FSDP + QLoRA {#sec-fsdp-qlora}
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
---
|
||||
title: Context Parallelism
|
||||
title: Sequence Parallelism
|
||||
description: Train with long sequences split across multiple GPUs.
|
||||
---
|
||||
|
||||
Context parallelism is a technique that splits sequences across multiple GPUs,
|
||||
Sequence parallelism is a technique that splits sequences across multiple GPUs,
|
||||
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
|
||||
GPU processes a different portion of the sequence, and the results are aggregated
|
||||
through a ring communication pattern.
|
||||
|
||||
## When to Use Context Parallelism
|
||||
## When to Use Sequence Parallelism
|
||||
|
||||
Use context parallelism when:
|
||||
Use sequence parallelism when:
|
||||
|
||||
- You need to train with sequence lengths that don't fit into a single GPU's memory
|
||||
- You have multiple GPUs available
|
||||
@@ -18,11 +18,11 @@ Use context parallelism when:
|
||||
|
||||
## Configuration
|
||||
|
||||
To enable context parallelism, add the following to your configuration file:
|
||||
To enable sequence parallelism, add the following to your configuration file:
|
||||
|
||||
```yaml
|
||||
# Set to a divisor (> 1) of the number of GPUs available
|
||||
context_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
@@ -30,23 +30,23 @@ heads_k_stride: 1
|
||||
ring_attn_func:
|
||||
```
|
||||
|
||||
The `context_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||
- With 4 GPUs, valid values would be 2 or 4
|
||||
|
||||
## Implementation Details
|
||||
|
||||
When context parallelism is enabled:
|
||||
When sequence parallelism is enabled:
|
||||
|
||||
1. Each sequence is divided into equal chunks across the GPUs in a context parallel group
|
||||
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
|
||||
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
|
||||
3. Position IDs are adjusted to maintain proper relative positions
|
||||
4. The trainer uses special ring communication patterns for attention operations
|
||||
|
||||
## Requirements
|
||||
|
||||
To use context parallelism, you need:
|
||||
To use sequence parallelism, you need:
|
||||
|
||||
- Multiple GPUs (at least 2)
|
||||
- The `ring-flash-attn` package. Install with:
|
||||
@@ -66,7 +66,7 @@ sequence_len: 8192
|
||||
|
||||
...
|
||||
|
||||
context_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
@@ -79,22 +79,22 @@ ring_attn_func:
|
||||
This will train the Llama 3 8B model with 8K context length, with each sequence split
|
||||
into 2 subsequences of length 4096 across 2 GPUs.
|
||||
|
||||
## Sample Packing with Context Parallelism
|
||||
## Sample Packing with Sequence Parallelism
|
||||
|
||||
Context parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
|
||||
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
|
||||
|
||||
1. Samples are first packed together
|
||||
2. The packed sequences are then divided across GPUs in the context parallel group
|
||||
2. The packed sequences are then divided across GPUs in the sequence parallel group
|
||||
3. Position IDs are automatically adjusted to maintain proper relative positions
|
||||
|
||||
## Effect on Batch Size
|
||||
|
||||
When using context parallelism, your effective global batch size is **divided** by the `context_parallel_degree`. This happens because:
|
||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||
|
||||
- Each group of `context_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||
- The number of batches processed per step decreases
|
||||
|
||||
For example:
|
||||
- With 8 GPUs and no context parallelism: 8 different batches processed per step
|
||||
- With 8 GPUs and `context_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
59
docs/telemetry.qmd
Normal file
59
docs/telemetry.qmd
Normal file
@@ -0,0 +1,59 @@
|
||||
---
|
||||
title: Telemetry
|
||||
description: A description of the opt-in telemetry implementation in Axolotl.
|
||||
---
|
||||
|
||||
# Telemetry in Axolotl
|
||||
|
||||
Axolotl implements anonymous telemetry to help maintainers understand how the library
|
||||
is used and where users encounter issues. This data helps prioritize features, optimize
|
||||
performance, and fix bugs.
|
||||
|
||||
## Data Collection
|
||||
|
||||
We collect:
|
||||
|
||||
- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
|
||||
version, etc.
|
||||
- Hardware info: CPU count, memory, GPU count and models
|
||||
- Runtime metrics: Training progress, memory usage, timing information
|
||||
- Usage patterns: Models (from a whitelist) and configurations used
|
||||
- Error tracking: Stack traces and error messages (sanitized to remove personal
|
||||
information)
|
||||
|
||||
Personally identifiable information (PII) is not collected.
|
||||
|
||||
## Implementation
|
||||
|
||||
Telemetry is implemented using PostHog and consists of:
|
||||
|
||||
- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
|
||||
telemetry system and provides methods for tracking events.
|
||||
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
|
||||
sends sanitized stack traces.
|
||||
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
|
||||
runtime metrics during training.
|
||||
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
|
||||
runtime metrics telemetry.
|
||||
|
||||
The telemetry system will block training startup for 15 seconds to ensure users are
|
||||
aware of data collection, unless telemetry is explicitly enabled or disabled.
|
||||
|
||||
## Opt-In Mechanism
|
||||
|
||||
Telemetry is **disabled by default** on an opt-in basis. To enable it, set `AXOLOTL_DO_NOT_TRACK=0`.
|
||||
|
||||
To remove the warning message about telemetry that is displayed on train, etc. startup,
|
||||
explicitly set: `AXOLOTL_DO_NOT_TRACK=0` (enable telemetry) or `AXOLOTL_DO_NOT_TRACK=1`
|
||||
(explicitly disable telemetry).
|
||||
|
||||
**Note**: Telemetry will move to an opt-out model in a later release.
|
||||
|
||||
## Privacy
|
||||
|
||||
- All path-like config information is automatically redacted from telemetry data
|
||||
- Model information is only collected for whitelisted organizations
|
||||
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
|
||||
- Each run generates a unique anonymous ID
|
||||
- This allows us to link different telemetry events in a single same training run
|
||||
- Telemetry is only sent from the main process to avoid duplicate events
|
||||
@@ -67,3 +67,6 @@ schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
|
||||
# telemetry
|
||||
posthog>=4.2.0
|
||||
|
||||
@@ -14,6 +14,8 @@ import yaml
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
@@ -28,6 +30,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = get_logger(__name__, use_environ=True)
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
|
||||
|
||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||
"""
|
||||
@@ -159,6 +163,7 @@ def plugin_set_cfg(cfg: DictDefault):
|
||||
plugin_manager.cfg = cfg
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_cfg(
|
||||
config: str | Path | DictDefault = Path("examples/"), **kwargs
|
||||
) -> DictDefault:
|
||||
@@ -192,6 +197,8 @@ def load_cfg(
|
||||
temp_file.close()
|
||||
cfg.axolotl_config_path = temp_file.name
|
||||
|
||||
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
|
||||
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
@@ -233,4 +240,6 @@ def load_cfg(
|
||||
setup_comet_env_vars(cfg)
|
||||
plugin_set_cfg(cfg)
|
||||
|
||||
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
@@ -16,6 +16,7 @@ from axolotl.cli.args import InferenceCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.chat_templates import (
|
||||
get_chat_template,
|
||||
get_chat_template_from_config,
|
||||
@@ -42,6 +43,7 @@ def get_multi_line_input() -> str:
|
||||
return instruction
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_inference(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
@@ -135,6 +137,7 @@ def do_inference(
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_inference_gradio(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
|
||||
@@ -9,12 +9,14 @@ from dotenv import load_dotenv
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
"""
|
||||
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
|
||||
@@ -73,7 +75,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
flash_attention=False,
|
||||
context_parallel_degree=None,
|
||||
sequence_parallel_degree=None,
|
||||
deepspeed=None,
|
||||
fsdp=None,
|
||||
fsdp_config=None,
|
||||
|
||||
@@ -24,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -118,6 +119,7 @@ def _distributed_checkpoint_to_merged_weights(
|
||||
return save_path_
|
||||
|
||||
|
||||
@send_errors
|
||||
def merge_fsdp_weights(
|
||||
checkpoint_dir: str,
|
||||
output_path: str,
|
||||
|
||||
@@ -18,6 +18,7 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.trainer import disable_datasets_caching
|
||||
@@ -25,6 +26,7 @@ from axolotl.utils.trainer import disable_datasets_caching
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
"""
|
||||
Preprocesses dataset specified in axolotl config.
|
||||
|
||||
@@ -10,6 +10,7 @@ from datasets import Dataset
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -45,6 +46,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
||||
)
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
@@ -112,6 +114,7 @@ def load_datasets(
|
||||
)
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_preference_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
|
||||
@@ -31,6 +31,8 @@ from transformers.training_args import OptimizerNames
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||
from axolotl.telemetry.callbacks import TelemetryCallback
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
@@ -145,6 +147,10 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
||||
|
||||
telemetry_manager = TelemetryManager.get_instance()
|
||||
if telemetry_manager.enabled:
|
||||
callbacks.append(TelemetryCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
|
||||
@@ -54,7 +54,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
context_parallel=self.cfg.context_parallel_degree > 1
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOContextParallelTrainer, AxolotlGRPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .relora import ReLoRATrainer
|
||||
from .trl import (
|
||||
|
||||
@@ -7,13 +7,11 @@ from __future__ import annotations
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
from typing import Callable, Literal, Optional
|
||||
|
||||
from axolotl.utils.ctx_managers.context_parallel.distributed import get_context_parallel_manager
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from torch import nn
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
DataLoader,
|
||||
@@ -67,32 +65,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
# SPDA device mesh init
|
||||
import torch.distributed as dist
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
mesh_shape = (
|
||||
world_size // 2,
|
||||
2,
|
||||
)
|
||||
self.world_mesh = dist.DeviceMesh(
|
||||
"cuda",
|
||||
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
||||
mesh_dim_names=("dp", "cp"),
|
||||
)
|
||||
|
||||
def training_step(
|
||||
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
|
||||
) -> torch.Tensor:
|
||||
ctx_manager = get_context_parallel_manager(
|
||||
world_mesh=self.world_mesh,
|
||||
model=model,
|
||||
)
|
||||
to_shard = {k: v for k, v in inputs.items() if v.ndim > 1}
|
||||
with ctx_manager(list(to_shard.values())):
|
||||
super().training_step(model, inputs, num_items_in_batch)
|
||||
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
|
||||
@@ -8,7 +8,7 @@ from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
from axolotl.core.trainers.grpo.trainer import (
|
||||
AxolotlGRPOContextParallelTrainer,
|
||||
AxolotlGRPOSequenceParallelTrainer,
|
||||
AxolotlGRPOTrainer,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -23,10 +23,10 @@ class GRPOStrategy:
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(
|
||||
cls, context_parallel: bool
|
||||
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOContextParallelTrainer]:
|
||||
if context_parallel:
|
||||
return AxolotlGRPOContextParallelTrainer
|
||||
cls, sequence_parallel: bool
|
||||
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
|
||||
if sequence_parallel:
|
||||
return AxolotlGRPOSequenceParallelTrainer
|
||||
return AxolotlGRPOTrainer
|
||||
|
||||
@classmethod
|
||||
@@ -69,8 +69,8 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||
|
||||
if cfg.context_parallel_degree > 1:
|
||||
grpo_args_kwargs["context_parallel_degree"] = cfg.context_parallel_degree
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
||||
|
||||
if trl.reward_weights:
|
||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||
|
||||
@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""Axolotl GRPO Config for GRPO training"""
|
||||
|
||||
context_parallel_degree: int | None = None
|
||||
sequence_parallel_degree: int | None = None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Repeat random sampler (similar to the one implemented in
|
||||
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
|
||||
context parallelism functionality; i.e., duplicating data across ranks in the same
|
||||
context parallel group.
|
||||
sequence parallelism functionality; i.e., duplicating data across ranks in the same
|
||||
sequence parallel group.
|
||||
"""
|
||||
|
||||
from typing import Iterator, Sized
|
||||
@@ -10,26 +10,26 @@ import torch
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
|
||||
class ContextParallelRepeatRandomSampler(Sampler):
|
||||
"""Sampler for GRPO training with context parallelism.
|
||||
class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
"""Sampler for GRPO training with sequence parallelism.
|
||||
|
||||
This sampler ensures:
|
||||
- Ranks in the same context parallel (SP) group receive identical data.
|
||||
- Ranks in the same sequence parallel (SP) group receive identical data.
|
||||
- Each index is repeated multiple times for sampling different completions.
|
||||
- Entire batches are repeated for reuse in multiple updates.
|
||||
- Data is properly distributed across CP groups.
|
||||
- Data is properly distributed across SP groups.
|
||||
|
||||
In the table below, the values represent dataset indices. Each CP group has
|
||||
`context_parallel_degree = 2` GPUs working together on the same data. There are 2
|
||||
CP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||
In the table below, the values represent dataset indices. Each SP group has
|
||||
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
|
||||
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||
|
||||
Context Parallel Groups
|
||||
Sequence Parallel Groups
|
||||
| SP0 | SP1 |
|
||||
| GPU 0 | GPU 1 | GPU 2 | GPU 3 |
|
||||
global_step step <---> mini_repeat_count=3
|
||||
<----------> batch_size=2 per CP group
|
||||
grad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- CP groups get different data
|
||||
▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each CP group GPU
|
||||
<----------> batch_size=2 per SP group
|
||||
grad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- SP groups get different data
|
||||
▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each SP group GPU
|
||||
|
|
||||
| 1 2 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Repeat same indices for iterations
|
||||
num_iterations=2 ▼ 1 3 [0 0 0 1 1 1] [2 2 2 3 3 3] <- When using gradient accumulation
|
||||
@@ -45,7 +45,7 @@ class ContextParallelRepeatRandomSampler(Sampler):
|
||||
rank: Rank of current process.
|
||||
batch_size: Number of samples per batch.
|
||||
repeat_count: How many times to repeat the full sampling process.
|
||||
context_parallel_degree: Number of ranks in a context parallel group.
|
||||
sequence_parallel_degree: Number of ranks in a sequence parallel group.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
seed: Random seed for shuffling.
|
||||
drop_last: Whether to drop the last incomplete batch.
|
||||
@@ -59,7 +59,7 @@ class ContextParallelRepeatRandomSampler(Sampler):
|
||||
rank: int,
|
||||
batch_size: int = 1,
|
||||
repeat_count: int = 1,
|
||||
context_parallel_degree: int = 1,
|
||||
sequence_parallel_degree: int = 1,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
@@ -76,16 +76,16 @@ class ContextParallelRepeatRandomSampler(Sampler):
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
|
||||
# Context parallelism parameters
|
||||
self.context_parallel_degree = context_parallel_degree
|
||||
self.num_sp_groups = world_size // context_parallel_degree
|
||||
self.sp_group_id = rank // context_parallel_degree
|
||||
# Sequence parallelism parameters
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.num_sp_groups = world_size // sequence_parallel_degree
|
||||
self.sp_group_id = rank // sequence_parallel_degree
|
||||
|
||||
# Adjust dataset size for distributed sampling
|
||||
self.num_samples = len(self.dataset)
|
||||
self.total_size = self.num_samples
|
||||
|
||||
# Calculate effective number of samples per CP group
|
||||
# Calculate effective number of samples per SP group
|
||||
if (
|
||||
self.drop_last
|
||||
and self.total_size % (self.num_sp_groups * self.batch_size) != 0
|
||||
@@ -125,8 +125,8 @@ class ContextParallelRepeatRandomSampler(Sampler):
|
||||
padding = indices[: self.batch_size - len(indices) % self.batch_size]
|
||||
indices += padding
|
||||
|
||||
# Subsample based on CP group ID
|
||||
# Each CP group gets distinct batches of data
|
||||
# Subsample based on SP group ID
|
||||
# Each SP group gets distinct batches of data
|
||||
batch_indices = []
|
||||
for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
|
||||
start_idx = i + self.sp_group_id * self.batch_size
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Axolotl GRPO trainers (with and without context parallelism handling)"""
|
||||
"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
|
||||
|
||||
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
||||
|
||||
@@ -41,7 +41,7 @@ from trl.trainer.grpo_config import GRPOConfig
|
||||
from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
||||
from trl.trainer.utils import pad
|
||||
|
||||
from axolotl.core.trainers.grpo.sampler import ContextParallelRepeatRandomSampler
|
||||
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
|
||||
@@ -59,8 +59,8 @@ class AxolotlGRPOTrainer(
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
|
||||
class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for context parallelism handling"""
|
||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -97,11 +97,11 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
|
||||
)
|
||||
|
||||
# Get number of CP groups (number of processes divided by CP degree)
|
||||
# Get number of SP groups (number of processes divided by SP degree)
|
||||
num_processes = self.accelerator.num_processes
|
||||
num_sp_groups = num_processes // self.args.context_parallel_degree
|
||||
num_sp_groups = num_processes // self.args.sequence_parallel_degree
|
||||
|
||||
# Calculate batch size per CP group (not per process)
|
||||
# Calculate batch size per SP group (not per process)
|
||||
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
||||
possible_values = [
|
||||
n_gen
|
||||
@@ -111,7 +111,7 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"The batch size per CP group ({num_sp_groups} x "
|
||||
f"The batch size per SP group ({num_sp_groups} x "
|
||||
f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
|
||||
f"the number of generations per prompt ({self.num_generations}). Given "
|
||||
"the current configuration, the valid values for the number of "
|
||||
@@ -119,7 +119,7 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
)
|
||||
|
||||
if self.args.eval_strategy != "no":
|
||||
# If context parallelism is enabled, calculate batch size per CP group
|
||||
# If sequence parallelism is enabled, calculate batch size per SP group
|
||||
sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups # type: ignore[union-attr]
|
||||
possible_values = [
|
||||
n_gen
|
||||
@@ -129,8 +129,8 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"With context parallelism (degree {self.args.context_parallel_degree}), "
|
||||
f"the eval batch size per CP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
|
||||
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||
f"must be evenly divisible by the number of generations per prompt "
|
||||
f"({self.num_generations}). Given the current eval batch size, "
|
||||
f"the valid values for the number of generations are: {possible_values}."
|
||||
@@ -143,7 +143,7 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
self.local_world_size = 1
|
||||
|
||||
def train(self, *args, **kwargs):
|
||||
# Initialize the CP group
|
||||
# Initialize the SP group
|
||||
self.sp_group = get_ring_attn_group()
|
||||
self.rank = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
@@ -159,16 +159,16 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
* self.args.gradient_accumulation_steps
|
||||
)
|
||||
|
||||
return ContextParallelRepeatRandomSampler(
|
||||
return SequenceParallelRepeatRandomSampler(
|
||||
dataset=self.train_dataset,
|
||||
mini_repeat_count=self.num_generations,
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
batch_size=effective_batch_size
|
||||
// self.num_generations
|
||||
// self.args.context_parallel_degree,
|
||||
// self.args.sequence_parallel_degree,
|
||||
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
||||
context_parallel_degree=self.args.context_parallel_degree,
|
||||
sequence_parallel_degree=self.args.sequence_parallel_degree,
|
||||
shuffle=True,
|
||||
seed=self.args.seed,
|
||||
drop_last=True,
|
||||
@@ -226,11 +226,11 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
):
|
||||
self.accelerator.even_batches = False
|
||||
|
||||
# Return unprepared dataloader if using context parallelism
|
||||
# Return unprepared dataloader if using sequence parallelism
|
||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||
# slice each batch along the sequence dimension).
|
||||
if self.args.context_parallel_degree > 1:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
@@ -303,21 +303,21 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
if self.accelerator.is_main_process:
|
||||
if self.args.context_parallel_degree > 1:
|
||||
# Calculate context parallel group information
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate sequence parallel group information
|
||||
world_size = self.accelerator.num_processes
|
||||
context_parallel_degree = self.args.context_parallel_degree
|
||||
num_sp_groups = world_size // context_parallel_degree
|
||||
sequence_parallel_degree = self.args.sequence_parallel_degree
|
||||
num_sp_groups = world_size // sequence_parallel_degree
|
||||
|
||||
# Since processes in the same CP group have the same prompts, we need to ensure
|
||||
# we only take one copy of each prompt from each CP group
|
||||
# Since processes in the same SP group have the same prompts, we need to ensure
|
||||
# we only take one copy of each prompt from each SP group
|
||||
ordered_set_of_prompts = []
|
||||
for sp_group_id in range(num_sp_groups):
|
||||
# Get the first process from each CP group (typically the group leader)
|
||||
group_leader_rank = sp_group_id * context_parallel_degree
|
||||
# Get the first process from each SP group (typically the group leader)
|
||||
group_leader_rank = sp_group_id * sequence_parallel_degree
|
||||
|
||||
# Extract prompts from this CP group, accounting for num_generations duplicates
|
||||
# We only need prompts from one rank in each CP group
|
||||
# Extract prompts from this SP group, accounting for num_generations duplicates
|
||||
# We only need prompts from one rank in each SP group
|
||||
group_prompts = all_prompts_text[
|
||||
group_leader_rank
|
||||
* len(prompts_text) : (group_leader_rank + 1)
|
||||
@@ -330,7 +330,7 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||
# prompt individually.
|
||||
ordered_set_of_prompts = all_prompts_text[
|
||||
:: self.num_generations * self.args.context_parallel_degree
|
||||
:: self.num_generations * self.args.sequence_parallel_degree
|
||||
]
|
||||
|
||||
with profiling_context(self, "vLLM.generate"):
|
||||
@@ -347,28 +347,28 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
)
|
||||
else:
|
||||
completion_ids = [None] * (
|
||||
len(all_prompts_text) // self.args.context_parallel_degree
|
||||
len(all_prompts_text) // self.args.sequence_parallel_degree
|
||||
)
|
||||
|
||||
# Broadcast the completions from the main process to all processes
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
|
||||
# Determine the appropriate slice based on context parallelism
|
||||
if self.args.context_parallel_degree > 1:
|
||||
# Calculate CP group ID (which group of ranks this rank belongs to)
|
||||
# Determine the appropriate slice based on sequence parallelism
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
# Calculate the start index for this CP group
|
||||
# Calculate the start index for this SP group
|
||||
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
|
||||
|
||||
# All ranks in the same CP group get the same data slice
|
||||
# All ranks in the same SP group get the same data slice
|
||||
process_slice = slice(
|
||||
sp_group_start,
|
||||
sp_group_start + len(prompts),
|
||||
)
|
||||
completion_ids = completion_ids[process_slice]
|
||||
else:
|
||||
# Original behavior for non-context parallel case
|
||||
# Original behavior for non-sequence parallel case
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
@@ -578,20 +578,20 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
advantages = advantages / (std_grouped_rewards + 1e-4)
|
||||
|
||||
# Slice to keep only the local part of the data
|
||||
if self.args.context_parallel_degree > 1:
|
||||
# Calculate CP group ID (which group of ranks this rank belongs to)
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
# Calculate the start index for this CP group
|
||||
# Calculate the start index for this SP group
|
||||
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
|
||||
|
||||
# All ranks in the same CP group get the same data slice
|
||||
# All ranks in the same SP group get the same data slice
|
||||
process_slice = slice(
|
||||
sp_group_start,
|
||||
sp_group_start + len(prompts),
|
||||
)
|
||||
else:
|
||||
# Original behavior for non-context parallel case
|
||||
# Original behavior for non-sequence parallel case
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
|
||||
@@ -11,6 +11,7 @@ from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.train import (
|
||||
TrainDatasetMeta,
|
||||
setup_model_and_tokenizer,
|
||||
@@ -63,6 +64,7 @@ def evaluate_dataset(
|
||||
return metrics
|
||||
|
||||
|
||||
@send_errors
|
||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||
"""
|
||||
Evaluate a model on training and validation datasets.
|
||||
|
||||
@@ -19,6 +19,7 @@ from peft import (
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from axolotl.loaders.utils import get_linear_embedding_layers
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -162,6 +163,7 @@ def load_lora(
|
||||
return model, lora_config
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_adapter(
|
||||
model: PreTrainedModel,
|
||||
cfg: DictDefault,
|
||||
|
||||
@@ -46,6 +46,7 @@ from axolotl.loaders.utils import (
|
||||
load_model_config,
|
||||
)
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import (
|
||||
@@ -145,6 +146,7 @@ class ModelLoader:
|
||||
"""Property that determines if FSDP with QLoRA is enabled."""
|
||||
return self.cfg.fsdp and self.cfg.adapter == "qlora"
|
||||
|
||||
@send_errors
|
||||
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
|
||||
"""Load and prepare the model with all configurations and patches.
|
||||
|
||||
|
||||
@@ -8,12 +8,14 @@ from transformers import (
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
processor_kwargs: dict[str, Any] = {} # Do we actually need this?
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from transformers import (
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.distributed import (
|
||||
barrier,
|
||||
@@ -117,6 +118,7 @@ def modify_tokenizer_files(
|
||||
return tokenizer_dir
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_tokenizer(cfg):
|
||||
"""Load and configure the tokenizer based on the provided config."""
|
||||
model_config = load_model_config(cfg)
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
||||
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
||||
their context parallel version of Flash Attention 2.
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
|
||||
We also provide some patches for accelerate functions to prepare the dataloader for
|
||||
context parallelism training.
|
||||
sequence parallelism training.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
@@ -13,9 +13,9 @@ import inspect
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -63,15 +63,15 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
|
||||
|
||||
def register_ring_attn(
|
||||
context_parallel_degree: int,
|
||||
sequence_parallel_degree: int,
|
||||
heads_k_stride: int | None,
|
||||
ring_attn_func: RingAttnFunc | None,
|
||||
):
|
||||
"""Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
Args:
|
||||
context_parallel_degree: Context parallelism factor.
|
||||
heads_k_stride: Context parallelism K head stride size. Passed through to
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
||||
packing is enabled, it must be a `varlen` function; otherwise, it must be a
|
||||
@@ -80,18 +80,28 @@ def register_ring_attn(
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
LOG.info(
|
||||
"Enabling ring attention context parallelism: "
|
||||
f"each sequence will be processed across {context_parallel_degree} GPUs"
|
||||
if rank == 0:
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
)
|
||||
assert world_size % sequence_parallel_degree == 0, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
)
|
||||
|
||||
# Assign ranks to context parallel groups
|
||||
# Assign ranks to sequence parallel groups
|
||||
group_assignments = {}
|
||||
for i in range(world_size // context_parallel_degree):
|
||||
for i in range(world_size // sequence_parallel_degree):
|
||||
ring_attn_ranks = list(
|
||||
range(
|
||||
i * context_parallel_degree,
|
||||
(i + 1) * context_parallel_degree,
|
||||
i * sequence_parallel_degree,
|
||||
(i + 1) * sequence_parallel_degree,
|
||||
)
|
||||
)
|
||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||
@@ -103,7 +113,9 @@ def register_ring_attn(
|
||||
if rank in ring_attn_ranks:
|
||||
set_ring_attn_group(group)
|
||||
|
||||
LOG.info(f"Context parallel group assignments: {group_assignments}")
|
||||
# Log the GPU group assignments
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
@@ -138,7 +150,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
|
||||
|
||||
|
||||
def patch_prepare_data_loader():
|
||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the CP degree.
|
||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
||||
|
||||
Raies:
|
||||
RuntimeError: If source code to patch does not exist.
|
||||
@@ -164,15 +176,15 @@ def patch_prepare_data_loader():
|
||||
patched_function = namespace["prepare_data_loader"]
|
||||
|
||||
accelerate.data_loader.prepare_data_loader = patched_function
|
||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for CP support")
|
||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||
|
||||
|
||||
def patch_prepare_device_mesh(context_parallel_degree: int):
|
||||
def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||
that includes context parallelism with the specified degree.
|
||||
that includes sequence parallelism with the specified degree.
|
||||
|
||||
Args:
|
||||
context_parallel_degree (int): The degree of context parallelism to use.
|
||||
sequence_parallel_degree (int): The degree of sequence parallelism to use.
|
||||
"""
|
||||
|
||||
def _prepare_device_mesh(self):
|
||||
@@ -187,11 +199,11 @@ def patch_prepare_device_mesh(context_parallel_degree: int):
|
||||
):
|
||||
return self.state.ds_device_mesh
|
||||
|
||||
# Create device mesh with context parallelism
|
||||
# Create device mesh with sequence parallelism
|
||||
world_size = dist.get_world_size()
|
||||
mesh_shape = (
|
||||
world_size // context_parallel_degree,
|
||||
context_parallel_degree,
|
||||
world_size // sequence_parallel_degree,
|
||||
sequence_parallel_degree,
|
||||
)
|
||||
device_ids = list(range(world_size))
|
||||
|
||||
@@ -209,5 +221,5 @@ def patch_prepare_device_mesh(context_parallel_degree: int):
|
||||
|
||||
LOG.info(
|
||||
"Successfully patched Accelerator._prepare_device_mesh "
|
||||
f"with context_parallel_degree={context_parallel_degree}"
|
||||
f"with sequence_parallel_degree={sequence_parallel_degree}"
|
||||
)
|
||||
|
||||
164
src/axolotl/telemetry/callbacks.py
Normal file
164
src/axolotl/telemetry/callbacks.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Trainer callbacks for reporting runtime metrics at regular intervals."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
TIME_SINCE_LAST = 30
|
||||
|
||||
|
||||
class TelemetryCallback(TrainerCallback):
|
||||
"""
|
||||
Trainer callback for tracking and reporting runtime metrics.
|
||||
|
||||
This callback tracks training progress, runtime, and memory usage,
|
||||
sending telemetry at configurable intervals.
|
||||
"""
|
||||
|
||||
report_interval_steps: int = 100
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the metrics callback."""
|
||||
self.tracker = RuntimeMetricsTracker()
|
||||
self.telemetry_manager = TelemetryManager.get_instance()
|
||||
self.current_epoch = -1
|
||||
self.start_time = time.time()
|
||||
self.last_report_time = None
|
||||
self.last_report_step = 0
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Handle training start."""
|
||||
self.telemetry_manager.send_event(event_type="train-start")
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_train_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Handle training end."""
|
||||
# Send training completion event
|
||||
self.telemetry_manager.send_event(
|
||||
event_type="train-end",
|
||||
properties=self._extract_last_metrics(state)
|
||||
| self.tracker.metrics.to_dict(),
|
||||
)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_epoch_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Handle epoch start."""
|
||||
self.current_epoch += 1
|
||||
self.tracker.start_epoch(self.current_epoch)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_epoch_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Handle epoch end."""
|
||||
self.tracker.end_epoch(self.current_epoch)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Handle step end."""
|
||||
step = state.global_step
|
||||
self.tracker.update_step(step)
|
||||
|
||||
# Check if we should report metrics
|
||||
should_report = (
|
||||
step % self.report_interval_steps == 0
|
||||
or step == 1 # Always report first step
|
||||
or step - self.last_report_step >= self.report_interval_steps
|
||||
)
|
||||
|
||||
if should_report:
|
||||
current_time = time.time()
|
||||
if self.last_report_time is not None:
|
||||
time_since_last_report = current_time - self.last_report_time
|
||||
else:
|
||||
time_since_last_report = current_time - self.start_time
|
||||
steps_since_last_report = step - self.last_report_step
|
||||
|
||||
# Only report if enough time has passed
|
||||
if (
|
||||
step == 1
|
||||
or time_since_last_report >= TIME_SINCE_LAST
|
||||
or steps_since_last_report >= self.report_interval_steps
|
||||
):
|
||||
# Calculate steps per second for this interval
|
||||
if time_since_last_report > 0 and steps_since_last_report > 0:
|
||||
steps_per_second = steps_since_last_report / time_since_last_report
|
||||
else:
|
||||
steps_per_second = 0
|
||||
|
||||
# Update memory metrics
|
||||
self.tracker.update_memory_metrics()
|
||||
|
||||
# Prepare metrics to report
|
||||
metrics = self._extract_last_metrics(state) | {
|
||||
"step": step,
|
||||
"epoch": self.current_epoch,
|
||||
"progress": state.epoch, # Fractional epoch progress
|
||||
"steps_per_second": steps_per_second,
|
||||
"elapsed_time": current_time - self.start_time,
|
||||
"time_since_last_report": time_since_last_report,
|
||||
}
|
||||
|
||||
# Add memory metrics
|
||||
memory_metrics = self.tracker.get_memory_metrics()
|
||||
metrics.update({"memory": memory_metrics})
|
||||
|
||||
# Send telemetry
|
||||
self.telemetry_manager.send_event(
|
||||
event_type="train-progress", properties=metrics
|
||||
)
|
||||
|
||||
# Update last report time and step
|
||||
self.last_report_time = current_time
|
||||
self.last_report_step = step
|
||||
|
||||
def _extract_last_metrics(self, state: TrainerState) -> dict:
|
||||
"""Extract last loss and learning_rate from log history."""
|
||||
if not state.log_history:
|
||||
return {"loss": 0, "learning_rate": 0}
|
||||
|
||||
last_log = state.log_history[-1]
|
||||
return {
|
||||
"loss": last_log.get("loss", 0),
|
||||
"learning_rate": last_log.get("learning_rate", 0),
|
||||
}
|
||||
160
src/axolotl/telemetry/errors.py
Normal file
160
src/axolotl/telemetry/errors.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Telemetry utilities for exception and traceback information."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from functools import wraps
|
||||
from inspect import getmodule
|
||||
from typing import Any, Callable
|
||||
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
ERROR_HANDLED = False
|
||||
|
||||
|
||||
def sanitize_stack_trace(stack_trace: str) -> str:
|
||||
"""
|
||||
Remove personal information from stack trace messages while keeping Python package codepaths.
|
||||
|
||||
This function identifies Python packages by looking for common patterns in virtual environment
|
||||
and site-packages directories, preserving the package path while removing user-specific paths.
|
||||
|
||||
Args:
|
||||
stack_trace: The original stack trace string.
|
||||
|
||||
Returns:
|
||||
A sanitized version of the stack trace with Python package paths preserved.
|
||||
"""
|
||||
# Split the stack trace into lines to process each file path separately
|
||||
lines = stack_trace.split("\n")
|
||||
sanitized_lines = []
|
||||
|
||||
# Regular expression to find file paths in the stack trace
|
||||
path_pattern = re.compile(r'(?:File ")(.*?)(?:")')
|
||||
|
||||
# Regular expression to identify paths in site-packages or dist-packages
|
||||
# This matches path segments like "site-packages/package_name" or "dist-packages/package_name"
|
||||
site_packages_pattern = re.compile(
|
||||
r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
|
||||
)
|
||||
|
||||
# Additional common virtual environment patterns
|
||||
venv_lib_pattern = re.compile(
|
||||
r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
|
||||
)
|
||||
|
||||
for line in lines:
|
||||
# Check if this line contains a file path
|
||||
path_match = path_pattern.search(line)
|
||||
|
||||
if path_match:
|
||||
full_path = path_match.group(1)
|
||||
sanitized_path = ""
|
||||
|
||||
# Try to match site-packages pattern
|
||||
site_packages_match = site_packages_pattern.search(full_path)
|
||||
venv_lib_match = venv_lib_pattern.search(full_path)
|
||||
|
||||
if site_packages_match:
|
||||
# Find the index where the matched pattern starts
|
||||
idx = full_path.find("site-packages")
|
||||
if idx == -1:
|
||||
idx = full_path.find("dist-packages")
|
||||
|
||||
# Keep from 'site-packages' onward
|
||||
if idx >= 0:
|
||||
sanitized_path = full_path[idx:]
|
||||
elif venv_lib_match:
|
||||
# For other virtual environment patterns, find the package directory
|
||||
match_idx = venv_lib_match.start(1)
|
||||
if match_idx > 0:
|
||||
# Keep from the package name onward
|
||||
package_name = venv_lib_match.group(1)
|
||||
idx = full_path.rfind(
|
||||
package_name, 0, match_idx + len(package_name)
|
||||
)
|
||||
if idx >= 0:
|
||||
sanitized_path = full_path[idx:]
|
||||
|
||||
# If we couldn't identify a package pattern but path contains 'axolotl'
|
||||
elif "axolotl" in full_path:
|
||||
idx = full_path.rfind("axolotl")
|
||||
if idx >= 0:
|
||||
sanitized_path = full_path[idx:]
|
||||
|
||||
# Apply the sanitization to the line
|
||||
if sanitized_path:
|
||||
line = line.replace(full_path, sanitized_path)
|
||||
else:
|
||||
# If we couldn't identify a package pattern, just keep the filename
|
||||
filename = os.path.basename(full_path)
|
||||
if filename:
|
||||
line = line.replace(full_path, filename)
|
||||
else:
|
||||
line = line.replace(full_path, "")
|
||||
|
||||
sanitized_lines.append(line)
|
||||
|
||||
return "\n".join(sanitized_lines)
|
||||
|
||||
|
||||
def send_errors(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to send exception info in a function. If an exception is raised, we send
|
||||
telemetry containing the stack trace and error message.
|
||||
|
||||
If an error occurs in a decorated function that is called by another decorated
|
||||
function, we'll only send telemetry corresponding to the lower-level function.
|
||||
|
||||
Args:
|
||||
func: Function to decorate.
|
||||
|
||||
Returns:
|
||||
Decorated function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
telemetry_manager = TelemetryManager.get_instance()
|
||||
|
||||
if not telemetry_manager.enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as exception:
|
||||
# Only track if we're not already handling an error. This prevents us from
|
||||
# capturing an error more than once in nested decorated function calls.
|
||||
global ERROR_HANDLED # pylint: disable=global-statement
|
||||
if not ERROR_HANDLED:
|
||||
ERROR_HANDLED = True
|
||||
|
||||
# Get function module path
|
||||
module = getmodule(func)
|
||||
module_path = (
|
||||
f"{module.__name__}.{func.__name__}" if module else func.__name__
|
||||
)
|
||||
|
||||
# Get stack trace
|
||||
stack_trace = "".join(
|
||||
traceback.format_exception(
|
||||
type(exception), exception, exception.__traceback__
|
||||
)
|
||||
)
|
||||
stack_trace = sanitize_stack_trace(stack_trace)
|
||||
|
||||
# Send error telemetry
|
||||
telemetry_manager.send_event(
|
||||
event_type=f"{module_path}-error",
|
||||
properties={
|
||||
"exception": str(exception),
|
||||
"stack_trace": stack_trace,
|
||||
},
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
417
src/axolotl/telemetry/manager.py
Normal file
417
src/axolotl/telemetry/manager.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""Telemetry manager and associated utilities."""
|
||||
|
||||
import atexit
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import posthog
|
||||
import psutil
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
POSTHOG_HOST = "https://app.posthog.com"
|
||||
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
|
||||
|
||||
OPT_IN_WARNING_SLEEP_SECONDS = 10
|
||||
OPT_IN_WARNING = (
|
||||
"\nTelemetry is currently disabled by default. If you'd like to help improve "
|
||||
"Axolotl, consider enabling it by setting AXOLOTL_DO_NOT_TRACK=0 in your environment.\n\n"
|
||||
"Telemetry data helps us understand:\n"
|
||||
"- Which features are most used\n"
|
||||
"- What hardware configurations to prioritize\n"
|
||||
"- Where users encounter errors\n\n"
|
||||
"Personally identifiable information (PII) is not collected.\n\n"
|
||||
"To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) "
|
||||
"or AXOLOTL_DO_NOT_TRACK=1 (explicitly disable telemetry).\n\n"
|
||||
"Note: Telemetry will move to an opt-out in a later release.\n\n"
|
||||
"For details, see: https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html\n\n"
|
||||
f"Sleeping for {OPT_IN_WARNING_SLEEP_SECONDS}s..."
|
||||
)
|
||||
|
||||
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
|
||||
|
||||
# NOTE: Need to keep these up to date with any config schema changes
|
||||
FIELDS_TO_REDACT = {
|
||||
"base_model",
|
||||
"tokenizer_config",
|
||||
"base_model_config",
|
||||
"pretraining_dataset", # NOTE: this field may be a string or a dictionary
|
||||
"resume_from_checkpoint",
|
||||
"hub_model_id",
|
||||
}
|
||||
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
|
||||
PATH_INDICATORS = {"path", "dir"}
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
RELEVANT_PACKAGES = {
|
||||
"torch",
|
||||
"transformers",
|
||||
"trl",
|
||||
"datasets",
|
||||
"peft",
|
||||
"bitsandbytes",
|
||||
"accelerate",
|
||||
"optimum",
|
||||
"deepspeed",
|
||||
"ray",
|
||||
"axolotl",
|
||||
"triton",
|
||||
"mamba-ssm",
|
||||
"flash-attn",
|
||||
"xformers",
|
||||
"autoawq",
|
||||
"tokenizers",
|
||||
"sentencepiece",
|
||||
"torchao",
|
||||
"lm_eval",
|
||||
}
|
||||
|
||||
|
||||
def is_main_process() -> bool:
|
||||
"""
|
||||
Check whether we're running in the main process.
|
||||
|
||||
Note:
|
||||
We're using this function instead of `torch.utils.distributed.is_main_process`
|
||||
causes issues with DeepSpeed world_size since. This function avoids that issue
|
||||
by checking env vars that are set by various launchers.
|
||||
|
||||
Returns:
|
||||
Whether we're running in the main process.
|
||||
"""
|
||||
# If PyTorch distributed is already initialized, use it
|
||||
if torch.distributed.is_initialized():
|
||||
return torch.distributed.get_rank() == 0
|
||||
|
||||
# Otherwise check environment variables for global rank
|
||||
# NOTE: need to verify this in SLURM / OpenMPI environments
|
||||
global_rank = int(
|
||||
os.environ.get(
|
||||
"RANK",
|
||||
os.environ.get(
|
||||
"GLOBAL_RANK",
|
||||
os.environ.get(
|
||||
"SLURM_PROCID",
|
||||
os.environ.get(
|
||||
"OMPI_COMM_WORLD_RANK",
|
||||
"0",
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return global_rank == 0
|
||||
|
||||
|
||||
class TelemetryManager:
|
||||
"""Manages telemetry collection and transmission"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
"""
|
||||
Telemetry manager constructor. Creates the singleton instance of this class if
|
||||
it doesn't already exist.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super(TelemetryManager, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""Telemetry manager initializer"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.enabled = self._check_telemetry_enabled()
|
||||
|
||||
if self.enabled:
|
||||
self.run_id = str(uuid.uuid4())
|
||||
self.whitelist = self._load_whitelist()
|
||||
|
||||
try:
|
||||
self.system_info = self._get_system_info()
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
LOG.warning(f"Error during system info collection: {e}")
|
||||
self.system_info = None
|
||||
|
||||
self._init_posthog()
|
||||
|
||||
# Register shutdown method to flush posthog telemetry
|
||||
atexit.register(self.shutdown)
|
||||
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "TelemetryManager":
|
||||
if cls._instance is None:
|
||||
cls._instance = TelemetryManager()
|
||||
|
||||
return cls._instance
|
||||
|
||||
def _check_telemetry_enabled(self) -> bool:
|
||||
"""
|
||||
Check if telemetry is enabled based on environment variables. We also check
|
||||
whether this is the main process (for the distributed setting and to avoid
|
||||
sending duplicate PostHog events per GPU).
|
||||
|
||||
Note: This is disabled by default on an opt-in basis. Set
|
||||
`AXOLOTL_DO_NOT_TRACK=0` to enable telemetry. We plan to move to an opt-out
|
||||
model in a later release. For more details, see
|
||||
https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
|
||||
|
||||
Returns:
|
||||
Boolean denoting whether telemetry is enabled or not.
|
||||
"""
|
||||
# Parse relevant env vars
|
||||
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
|
||||
do_not_track = os.getenv("DO_NOT_TRACK")
|
||||
|
||||
# Default to disabled (opt-in model for initial release)
|
||||
if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (
|
||||
"0",
|
||||
"1",
|
||||
"false",
|
||||
"true",
|
||||
):
|
||||
# Print opt-in info message for main process only
|
||||
if is_main_process():
|
||||
LOG.warning(OPT_IN_WARNING)
|
||||
time.sleep(OPT_IN_WARNING_SLEEP_SECONDS)
|
||||
|
||||
return False
|
||||
|
||||
# Only rank 0 will send telemetry
|
||||
if not is_main_process():
|
||||
return False
|
||||
|
||||
if do_not_track is None:
|
||||
do_not_track = "0"
|
||||
|
||||
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
|
||||
enabled = axolotl_do_not_track.lower() not in (
|
||||
"1",
|
||||
"true",
|
||||
) and do_not_track.lower() not in ("1", "true")
|
||||
|
||||
return enabled
|
||||
|
||||
def _load_whitelist(self) -> dict:
|
||||
"""Load HuggingFace Hub organization whitelist"""
|
||||
with open(WHITELIST_PATH, encoding="utf-8") as f:
|
||||
whitelist = yaml.safe_load(f)
|
||||
|
||||
# Send org strings to lowercase since model names are case insensitive
|
||||
whitelist["organizations"] = {
|
||||
org.lower() for org in whitelist["organizations"]
|
||||
}
|
||||
|
||||
return whitelist
|
||||
|
||||
def _is_whitelisted(self, value: str) -> bool:
|
||||
"""
|
||||
Check if model / dataset / etc. org is in whitelist.
|
||||
|
||||
Args:
|
||||
value: Value for one of `axolotl.telemetry.manager.FIELDS_WITH_ORGS`
|
||||
("base_model", etc.).
|
||||
|
||||
Returns:
|
||||
Boolean indicating whitelist membership.
|
||||
"""
|
||||
# NOTE: This membership-checking logic can be improved.
|
||||
# What happens when a local model path matches a whitelisted org?
|
||||
parts = value.split("/")
|
||||
if len(parts) < 2:
|
||||
return False
|
||||
org = parts[0]
|
||||
whitelisted = org.lower() in self.whitelist["organizations"]
|
||||
|
||||
return whitelisted
|
||||
|
||||
def _init_posthog(self):
|
||||
"""Initialize PostHog client"""
|
||||
posthog.host = POSTHOG_HOST
|
||||
posthog.project_api_key = POSTHOG_WRITE_KEY
|
||||
|
||||
def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Redact properties to remove any paths, so as to avoid inadvertently collecting
|
||||
private or personally identifiable information (PII). We also remove
|
||||
information related to Wandb, MLflow, etc. configuration.
|
||||
|
||||
Args:
|
||||
properties: Dictionary of properties to redact.
|
||||
|
||||
Returns:
|
||||
Properties dictionary with redaction applied.
|
||||
"""
|
||||
if not properties:
|
||||
return {}
|
||||
|
||||
def redact_value(value: Any, key: str = "") -> Any:
|
||||
"""Recursively sanitize values, redacting those with path-like keys"""
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
# Other redaction special cases
|
||||
if (
|
||||
key in FIELDS_TO_REDACT
|
||||
or any(prefix in key for prefix in PREFIXES_TO_REDACT)
|
||||
or any(indicator in key.lower() for indicator in PATH_INDICATORS)
|
||||
):
|
||||
# Fields with whitelisted orgs don't need to be redacted
|
||||
if not self._is_whitelisted(value):
|
||||
return "[REDACTED]"
|
||||
|
||||
# Handle nested values
|
||||
if isinstance(value, dict):
|
||||
return {k: redact_value(v, k) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [redact_value(item) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
# Create new dict with redacted values
|
||||
redacted = {k: redact_value(v, k) for k, v in properties.items()}
|
||||
|
||||
return redacted
|
||||
|
||||
def _get_system_info(self) -> dict[str, Any]:
|
||||
"""Collect system information for various hardware accelerators"""
|
||||
gpu_info = []
|
||||
accelerator_type = "none"
|
||||
|
||||
# NVIDIA GPUs
|
||||
if torch.cuda.is_available():
|
||||
accelerator_type = "cuda"
|
||||
for i in range(torch.cuda.device_count()):
|
||||
gpu_info.append(
|
||||
{
|
||||
"name": torch.cuda.get_device_name(i),
|
||||
"memory": torch.cuda.get_device_properties(i).total_memory,
|
||||
}
|
||||
)
|
||||
|
||||
# AMD GPUs
|
||||
elif hasattr(torch, "hip") and torch.hip.is_available():
|
||||
accelerator_type = "hip"
|
||||
for i in range(torch.hip.device_count()):
|
||||
gpu_info.append(
|
||||
{
|
||||
"name": torch.hip.get_device_name(i),
|
||||
"memory": (
|
||||
torch.hip.get_device_properties(i).total_memory
|
||||
if hasattr(torch.hip, "get_device_properties")
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Apple Silicon
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
accelerator_type = "mps"
|
||||
gpu_info.append(
|
||||
{
|
||||
"name": "Apple Silicon",
|
||||
# NOTE: this is memory allocated to this process, not total memory
|
||||
"memory": torch.mps.driver_allocated_memory(),
|
||||
}
|
||||
)
|
||||
|
||||
# Intel GPUs
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
accelerator_type = "xpu"
|
||||
for i in range(torch.xpu.device_count()):
|
||||
memory = None
|
||||
if hasattr(torch.xpu, "get_device_properties"):
|
||||
memory = torch.xpu.get_device_properties(i).total_memory
|
||||
|
||||
gpu_info.append(
|
||||
{
|
||||
"name": torch.xpu.get_device_name(i),
|
||||
"memory": memory,
|
||||
}
|
||||
)
|
||||
|
||||
# NPUs
|
||||
elif hasattr(torch, "npu") and torch.npu.is_available():
|
||||
accelerator_type = "npu"
|
||||
for i in range(torch.npu.device_count()):
|
||||
memory = None
|
||||
if hasattr(torch.npu, "get_device_properties"):
|
||||
memory = torch.npu.get_device_properties(i).total_memory
|
||||
|
||||
gpu_info.append(
|
||||
{
|
||||
"name": torch.npu.get_device_name(i),
|
||||
"memory": memory,
|
||||
}
|
||||
)
|
||||
|
||||
# Get relevant package versions
|
||||
installed_packages = {}
|
||||
for package in RELEVANT_PACKAGES:
|
||||
try:
|
||||
version = importlib.metadata.version(package)
|
||||
installed_packages[f"{package}_version"] = version
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
pass
|
||||
|
||||
return {
|
||||
"os": platform.system(),
|
||||
"python_version": platform.python_version(),
|
||||
"cpu_count": psutil.cpu_count(),
|
||||
"memory_total": psutil.virtual_memory().total,
|
||||
"accelerator_type": accelerator_type,
|
||||
"accelerator_count": len(gpu_info),
|
||||
"accelerator_info": gpu_info,
|
||||
**installed_packages,
|
||||
}
|
||||
|
||||
def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
|
||||
"""Send a telemetry event"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
if properties is None:
|
||||
properties = {}
|
||||
|
||||
# Sanitize properties to remove PII
|
||||
properties = self._redact_paths(properties)
|
||||
|
||||
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
|
||||
try:
|
||||
# Send event via PostHog
|
||||
posthog.capture(
|
||||
distinct_id=self.run_id,
|
||||
event=event_type,
|
||||
properties=properties,
|
||||
disable_geoip=True,
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
LOG.warning(f"Failed to send telemetry event: {e}")
|
||||
|
||||
# Additionally, send system info telemetry when loading config.
|
||||
# NOTE: Is this the best place for this?
|
||||
if event_type == "config-loaded":
|
||||
self.send_system_info()
|
||||
|
||||
def send_system_info(self):
|
||||
"""Helper method for sending system info"""
|
||||
if self.system_info is not None:
|
||||
self.send_event(event_type="system-info", properties=self.system_info)
|
||||
|
||||
def shutdown(self):
|
||||
"""Ensure all queued events are processed before shutdown"""
|
||||
if self.enabled:
|
||||
posthog.flush()
|
||||
209
src/axolotl/telemetry/runtime_metrics.py
Normal file
209
src/axolotl/telemetry/runtime_metrics.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Telemetry utilities for runtime and memory metrics."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeMetrics:
|
||||
"""Container for runtime metrics to be tracked throughout training."""
|
||||
|
||||
# Timing metrics
|
||||
start_time: float
|
||||
epoch_start_times: dict[int, float] = field(init=False)
|
||||
epoch_end_times: dict[int, float] = field(init=False)
|
||||
|
||||
# Memory metrics
|
||||
peak_cpu_memory: int = 0
|
||||
peak_gpu_memory: dict[int, int] = field(init=False)
|
||||
|
||||
# Progress metrics
|
||||
total_steps: int = 0
|
||||
current_epoch: int = 0
|
||||
current_step: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize empty metric mappings."""
|
||||
self.epoch_start_times = {}
|
||||
self.epoch_end_times = {}
|
||||
self.peak_gpu_memory = {}
|
||||
|
||||
@property
|
||||
def elapsed_time(self) -> float:
|
||||
"""Calculate total elapsed time in seconds."""
|
||||
return time.time() - self.start_time
|
||||
|
||||
def epoch_time(self, epoch: int) -> float | None:
|
||||
"""Calculate time taken for a specific epoch in seconds."""
|
||||
if epoch in self.epoch_start_times and epoch in self.epoch_end_times:
|
||||
return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]
|
||||
|
||||
return None
|
||||
|
||||
def average_epoch_time(self) -> float | None:
|
||||
"""Calculate average time per epoch in seconds."""
|
||||
completed_epochs = [
|
||||
epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times
|
||||
]
|
||||
if not completed_epochs:
|
||||
return None
|
||||
|
||||
total_time = 0.0
|
||||
for epoch in completed_epochs:
|
||||
epoch_time = self.epoch_time(epoch)
|
||||
if epoch_time is not None: # Check to avoid mypy warning
|
||||
total_time += epoch_time
|
||||
|
||||
return total_time / len(completed_epochs)
|
||||
|
||||
def steps_per_second(self) -> float | None:
|
||||
"""Calculate average steps per second across all training."""
|
||||
if self.total_steps == 0 or self.elapsed_time == 0:
|
||||
return None
|
||||
|
||||
return self.total_steps / self.elapsed_time
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert metrics to a dictionary for telemetry reporting."""
|
||||
metrics = {
|
||||
"total_time_seconds": self.elapsed_time,
|
||||
"total_steps": self.total_steps,
|
||||
"steps_per_second": self.steps_per_second(),
|
||||
"epochs_completed": len(
|
||||
[
|
||||
epoch
|
||||
for epoch in self.epoch_start_times
|
||||
if epoch in self.epoch_end_times
|
||||
]
|
||||
),
|
||||
"peak_cpu_memory_bytes": self.peak_cpu_memory,
|
||||
}
|
||||
|
||||
# Add per-epoch timing if available
|
||||
epoch_times: dict[str, float] = {}
|
||||
for epoch in sorted(self.epoch_end_times.keys()):
|
||||
time_taken = self.epoch_time(epoch)
|
||||
if time_taken is not None:
|
||||
epoch_times[f"epoch_{epoch}_seconds"] = time_taken
|
||||
|
||||
if epoch_times:
|
||||
metrics["epoch_times"] = epoch_times # type: ignore
|
||||
metrics["average_epoch_time_seconds"] = self.average_epoch_time()
|
||||
|
||||
# Add GPU memory metrics if available
|
||||
if self.peak_gpu_memory:
|
||||
gpu_metrics: dict[str, int] = {}
|
||||
for gpu_id, memory in self.peak_gpu_memory.items():
|
||||
gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory
|
||||
metrics["gpu_memory"] = gpu_metrics # type: ignore
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
class RuntimeMetricsTracker:
|
||||
"""Tracker for runtime metrics during training."""
|
||||
|
||||
update_interval = 100
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the runtime metrics tracker."""
|
||||
self.metrics = RuntimeMetrics(start_time=time.time())
|
||||
self.telemetry_manager = TelemetryManager.get_instance()
|
||||
|
||||
def start_epoch(self, epoch: int):
|
||||
"""Record the start of a new epoch."""
|
||||
self.metrics.current_epoch = epoch
|
||||
self.metrics.epoch_start_times[epoch] = time.time()
|
||||
self.update_memory_metrics()
|
||||
|
||||
def end_epoch(self, epoch: int):
|
||||
"""Record the end of an epoch."""
|
||||
self.metrics.epoch_end_times[epoch] = time.time()
|
||||
|
||||
def update_step(self, step: int):
|
||||
"""Update the current step count."""
|
||||
self.metrics.current_step = step
|
||||
self.metrics.total_steps += 1
|
||||
|
||||
# Periodically update memory metrics
|
||||
if step % self.update_interval == 0:
|
||||
self.update_memory_metrics()
|
||||
|
||||
def _get_allocated_memory(self) -> dict[int, int]:
|
||||
"""
|
||||
Helper function for getting accelerator-agnostic allocated memory.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping device IDs to allocated memory in bytes
|
||||
"""
|
||||
memory_used: dict[int, int] = {}
|
||||
|
||||
# NVIDIA GPUs
|
||||
if torch.cuda.is_available():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
memory_used[i] = torch.cuda.memory_allocated(i)
|
||||
|
||||
# AMD GPUs
|
||||
elif hasattr(torch, "hip") and torch.hip.is_available():
|
||||
for i in range(torch.hip.device_count()):
|
||||
if hasattr(torch.hip, "memory_allocated"):
|
||||
memory_used[i] = torch.hip.memory_allocated(i)
|
||||
|
||||
# Apple Silicon
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
# MPS doesn't have per-device memory stats since there's only one device
|
||||
if hasattr(torch.mps, "current_allocated_memory"):
|
||||
memory_used[0] = torch.mps.current_allocated_memory()
|
||||
|
||||
# Intel GPUs
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
if hasattr(torch.xpu, "memory_allocated"):
|
||||
memory_used[i] = torch.xpu.memory_allocated(i)
|
||||
|
||||
# NPUs
|
||||
elif hasattr(torch, "npu") and torch.npu.is_available():
|
||||
for i in range(torch.npu.device_count()):
|
||||
if hasattr(torch.npu, "memory_allocated"):
|
||||
memory_used[i] = torch.npu.memory_allocated(i)
|
||||
|
||||
return memory_used
|
||||
|
||||
def update_memory_metrics(self):
|
||||
"""Update peak memory usage metrics."""
|
||||
# CPU memory
|
||||
cpu_memory = psutil.Process().memory_info().rss
|
||||
self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)
|
||||
|
||||
# GPU memory (if available)
|
||||
memory_used = self._get_allocated_memory()
|
||||
for i, memory in memory_used.items():
|
||||
self.metrics.peak_gpu_memory[i] = max(
|
||||
self.metrics.peak_gpu_memory.get(i, 0), memory
|
||||
)
|
||||
|
||||
def get_memory_metrics(self) -> dict[str, Any]:
|
||||
"""Get the current memory metrics as a dictionary."""
|
||||
memory_metrics = {
|
||||
"cpu_memory_bytes": psutil.Process().memory_info().rss,
|
||||
"peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
|
||||
}
|
||||
|
||||
# GPU memory (if available)
|
||||
memory_used = self._get_allocated_memory()
|
||||
for i, memory in memory_used.items():
|
||||
memory_metrics[f"gpu_{i}_memory_bytes"] = memory
|
||||
memory_metrics[f"gpu_{i}_peak_memory_bytes"] = (
|
||||
self.metrics.peak_gpu_memory.get(i, 0)
|
||||
)
|
||||
|
||||
return memory_metrics
|
||||
17
src/axolotl/telemetry/whitelist.yaml
Normal file
17
src/axolotl/telemetry/whitelist.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
organizations:
|
||||
- "axolotl-ai-co"
|
||||
- "meta-llama"
|
||||
- "huggingface"
|
||||
- "nvidia"
|
||||
- "facebook"
|
||||
- "google"
|
||||
- "microsoft"
|
||||
- "deepseek-ai"
|
||||
- "HuggingFaceTB"
|
||||
- "mistralai"
|
||||
- "Qwen"
|
||||
- "unsloth"
|
||||
- "NousResearch"
|
||||
- "allenai"
|
||||
- "amd"
|
||||
- "tiiuae"
|
||||
@@ -25,13 +25,16 @@ from axolotl.common.datasets import TrainDatasetMeta
|
||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
)
|
||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders import (
|
||||
ModelLoader,
|
||||
load_processor,
|
||||
load_tokenizer,
|
||||
)
|
||||
from axolotl.utils.ctx_managers import ContextParallelContextManager
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
@@ -46,6 +49,9 @@ except ImportError:
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||
|
||||
|
||||
def setup_model_and_tokenizer(
|
||||
cfg: DictDefault,
|
||||
@@ -63,7 +69,10 @@ def setup_model_and_tokenizer(
|
||||
`None`), and processor (if multimodal, else `None`).
|
||||
"""
|
||||
# Load tokenizer
|
||||
LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
LOG.debug(
|
||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||
main_process_only=True,
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
# Load processor for multimodal models if needed
|
||||
@@ -82,6 +91,14 @@ def setup_model_and_tokenizer(
|
||||
if model.generation_config is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
TELEMETRY_MANAGER.send_event(
|
||||
event_type="model-load", properties=model.config.to_dict()
|
||||
)
|
||||
if peft_config:
|
||||
TELEMETRY_MANAGER.send_event(
|
||||
event_type="peft-config-load", properties=peft_config.to_dict()
|
||||
)
|
||||
|
||||
# Apply freezing if specified
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||
@@ -147,7 +164,7 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
|
||||
|
||||
|
||||
def setup_signal_handler(
|
||||
cfg: DictDefault, model: PeftModel | PreTrainedModel, safe_serialization: bool
|
||||
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
|
||||
):
|
||||
"""
|
||||
Set up signal handler for graceful termination.
|
||||
@@ -201,20 +218,15 @@ def execute_training(
|
||||
)
|
||||
)
|
||||
|
||||
if cfg.context_parallel_degree > 1 and not cfg.sdp_attention:
|
||||
# Models to enter context parallel manager for
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
models = [trainer.model]
|
||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||
models.append(trainer.ref_model)
|
||||
|
||||
# Attention backend
|
||||
backend = "sdp_attention" if cfg.sdp_attention else "flash_attention"
|
||||
|
||||
stack.enter_context(
|
||||
ContextParallelContextManager(
|
||||
SequenceParallelContextManager(
|
||||
models=models,
|
||||
backend=backend,
|
||||
context_parallel_degree=cfg.context_parallel_degree,
|
||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
heads_k_stride=cfg.heads_k_stride,
|
||||
@@ -228,7 +240,7 @@ def execute_training(
|
||||
def save_trained_model(
|
||||
cfg: DictDefault,
|
||||
trainer: Any,
|
||||
model: PeftModel | PreTrainedModel,
|
||||
model: PreTrainedModel,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
"""
|
||||
@@ -379,7 +391,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
|
||||
def save_initial_configs(
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
model: PeftModel | PreTrainedModel,
|
||||
model: PreTrainedModel,
|
||||
peft_config: PeftConfig | None,
|
||||
processor: ProcessorMixin | None,
|
||||
):
|
||||
@@ -433,7 +445,7 @@ def setup_model_card(cfg: DictDefault):
|
||||
|
||||
def handle_untrained_tokens_fix(
|
||||
cfg: DictDefault,
|
||||
model: PeftModel | PreTrainedModel,
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
train_dataset: Dataset,
|
||||
safe_serialization: bool,
|
||||
@@ -476,7 +488,7 @@ def handle_untrained_tokens_fix(
|
||||
|
||||
|
||||
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||
Trainer,
|
||||
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
||||
PeftModel | PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PeftConfig | None,
|
||||
@@ -521,6 +533,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
||||
model_ref=model_ref,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
PLUGIN_MANAGER.post_trainer_create(cfg, trainer)
|
||||
|
||||
return (
|
||||
trainer,
|
||||
@@ -531,6 +544,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
||||
)
|
||||
|
||||
|
||||
@send_errors
|
||||
def train(
|
||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
|
||||
@@ -555,8 +569,11 @@ def train(
|
||||
processor,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_trainer_create(cfg, trainer)
|
||||
# Determine if we need to resume from a checkpoint
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
|
||||
# Configuration for saving
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
@@ -571,7 +588,6 @@ def train(
|
||||
setup_model_card(cfg)
|
||||
|
||||
# Execute the training
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||
|
||||
# Save the trained model and cleanup
|
||||
@@ -579,7 +595,6 @@ def train(
|
||||
create_model_card(cfg, trainer)
|
||||
if not cfg.use_ray:
|
||||
cleanup_distributed()
|
||||
|
||||
plugin_manager.post_train(cfg, model)
|
||||
PLUGIN_MANAGER.post_train(cfg, model)
|
||||
|
||||
return model, tokenizer, trainer
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Init for context manager submodule."""
|
||||
"""Init for context manager submodule"""
|
||||
|
||||
from .context_parallel.manager import ContextParallelContextManager
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
__all__ = ["ContextParallelContextManager"]
|
||||
from .sequence_parallel import SequenceParallelContextManager
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
# BSD 3-Clause License
|
||||
|
||||
# Copyright 2024 Meta
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice,this list
|
||||
# of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice, this
|
||||
# list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its contributors may
|
||||
# be used to endorse or promote products derived from this software without specific
|
||||
# prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
|
||||
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
|
||||
# SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
|
||||
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
|
||||
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
|
||||
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
|
||||
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
|
||||
# DAMAGE.
|
||||
|
||||
"""
|
||||
Distributed utils for SDPA context parallel implementation. Slightly modified from
|
||||
https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5c2/torchtune/training/_distributed.py.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
from typing import Callable, Generator, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
|
||||
def _get_sdpa_context() -> (
|
||||
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
|
||||
):
|
||||
"""
|
||||
Creates a context manager to confine to flash/efficient/cuDNN attention backends.
|
||||
|
||||
Returns:
|
||||
A context manager function that takes an optional context parallel context.
|
||||
"""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context(cp_context: Union[Generator[None, None, None], None] = None):
|
||||
with contextlib.ExitStack() as stack:
|
||||
if cp_context is not None:
|
||||
stack.enter_context(
|
||||
sdpa_kernel(
|
||||
[
|
||||
SDPBackend.FLASH_ATTENTION,
|
||||
SDPBackend.EFFICIENT_ATTENTION,
|
||||
SDPBackend.CUDNN_ATTENTION,
|
||||
]
|
||||
)
|
||||
)
|
||||
stack.enter_context(cp_context)
|
||||
|
||||
yield
|
||||
|
||||
return context
|
||||
|
||||
|
||||
def get_context_parallel_manager(
|
||||
*,
|
||||
world_mesh: torch.distributed.DeviceMesh,
|
||||
model: nn.Module,
|
||||
) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
|
||||
"""
|
||||
Context manager for applying context parallelism to a model. In addition to applying the
|
||||
standard context manager to patch SDPA and shard model inputs and buffers along the sequence
|
||||
dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.
|
||||
|
||||
Args:
|
||||
world_mesh: Global device mesh.
|
||||
model: Model to apply context parallelism to.
|
||||
|
||||
Returns:
|
||||
A context manager applying context parallelism if enabled is True. Otherwise a context manager
|
||||
disabling the math SDPA backend.
|
||||
|
||||
Raises:
|
||||
ValueError: if enabled is True but world_mesh does not contain a "cp" dimension
|
||||
"""
|
||||
|
||||
if "cp" not in world_mesh.mesh_dim_names:
|
||||
raise ValueError(
|
||||
"Context parallel is enabled but no context parallel device mesh is provided."
|
||||
)
|
||||
# TODO: context parallel for multimodal models requires extra work
|
||||
# if not isinstance(model, TransformerDecoder):
|
||||
# raise ValueError("Context parallel is only supported for text models")
|
||||
# model_buffers = list(model.buffers())
|
||||
|
||||
# def get_all_buffers(module, prefix=""):
|
||||
# buffers = {}
|
||||
# for name, buffer in module.named_buffers(recurse=False):
|
||||
# full_name = f"{prefix}.{name}" if prefix else name
|
||||
# buffers[full_name] = buffer
|
||||
|
||||
# for name, child in module.named_children():
|
||||
# child_prefix = f"{prefix}.{name}" if prefix else name
|
||||
# buffers.update(get_all_buffers(child, child_prefix))
|
||||
|
||||
# return buffers
|
||||
|
||||
# model_buffers = get_all_buffers(model)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context(model_inputs: list[torch.Tensor]):
|
||||
# Create context parallel context if enabled
|
||||
cp_context = None
|
||||
if any([isinstance(input, BlockMask) for input in model_inputs]):
|
||||
raise ValueError(
|
||||
"Context parallel with flex attention is not yet supported"
|
||||
)
|
||||
set_rotate_method("allgather")
|
||||
|
||||
cp_context = context_parallel(
|
||||
world_mesh["cp"],
|
||||
# buffers=model_inputs + model_buffers,
|
||||
buffers=model_inputs,
|
||||
# buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
|
||||
buffer_seq_dims=[1] * len(model_inputs),
|
||||
no_restore_buffers=set(model_inputs),
|
||||
)
|
||||
|
||||
# Create and enter the train context with the optional cp_context
|
||||
sdpa_context = _get_sdpa_context()
|
||||
|
||||
with sdpa_context(cp_context):
|
||||
yield
|
||||
|
||||
return context
|
||||
@@ -1,216 +0,0 @@
|
||||
"""Module for Axolotl trainer context parallelism manager and utilities."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Callable, Literal
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
register_ring_attn,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.context_parallel.distributed import (
|
||||
get_context_parallel_manager,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.context_parallel.utils import (
|
||||
AllGatherWithGrad,
|
||||
apply_context_parallelism,
|
||||
)
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
|
||||
class ContextParallelContextManager:
|
||||
"""Context manager for context parallelism operations.
|
||||
|
||||
This class provides a context that will automatically apply context parallelism
|
||||
during model forward passes using a pre-forward hook, and gather outputs from
|
||||
across the context parallelism group using a post-forward hook.
|
||||
|
||||
Args:
|
||||
models: List of models to apply context parallelism to pre- and post- forward
|
||||
hooks.
|
||||
backend: Which attention backend to use.
|
||||
context_parallel_degree: Number of processes to split sequences over.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||
heads_k_stride: Context parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
models: list[PreTrainedModel],
|
||||
backend: Literal["sdp_attention", "flash_attention"],
|
||||
context_parallel_degree: int,
|
||||
gradient_accumulation_steps: int,
|
||||
ring_attn_func: RingAttnFunc,
|
||||
heads_k_stride: int | None,
|
||||
):
|
||||
self.models = models
|
||||
self.backend = backend
|
||||
self.context_parallel_degree = context_parallel_degree
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.heads_k_stride = heads_k_stride
|
||||
self._register_ring_attn()
|
||||
|
||||
# Store hook handles for removal
|
||||
self.hook_handles: list[RemovableHandle] = []
|
||||
|
||||
if self.backend == "flash_attention":
|
||||
# Set distributed info for local rank
|
||||
self.process_group = get_ring_attn_group()
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
|
||||
# Create a partially applied version of the apply_context_parallelism function
|
||||
self.apply_context_parallelism = functools.partial(
|
||||
apply_context_parallelism,
|
||||
local_rank=self.local_rank,
|
||||
local_world_size=self.local_world_size,
|
||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
# Store original sequence length and padding information
|
||||
self.original_seq_len = 0
|
||||
self.pad_len = 0
|
||||
else:
|
||||
# SPDA device mesh init
|
||||
world_size = dist.get_world_size()
|
||||
mesh_shape = (
|
||||
world_size // self.context_parallel_degree,
|
||||
self.context_parallel_degree,
|
||||
)
|
||||
world_mesh = dist.DeviceMesh(
|
||||
"cuda",
|
||||
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
||||
mesh_dim_names=("dp", "cp"),
|
||||
)
|
||||
|
||||
# SDPA context parallel managers
|
||||
self.context_parallel_managers = []
|
||||
for model in models:
|
||||
ctx_manager = get_context_parallel_manager(
|
||||
world_mesh=world_mesh,
|
||||
model=model,
|
||||
)
|
||||
self.context_parallel_managers.append(ctx_manager)
|
||||
|
||||
def __enter__(self):
|
||||
self._register_model_hooks()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Remove all hooks
|
||||
for handle in self.hook_handles:
|
||||
handle.remove()
|
||||
self.hook_handles = []
|
||||
|
||||
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
|
||||
|
||||
def _register_ring_attn(self):
|
||||
if self.backend == "flash_attention":
|
||||
# Initialize ring attn for context parallelism
|
||||
register_ring_attn(
|
||||
context_parallel_degree=self.context_parallel_degree,
|
||||
heads_k_stride=self.heads_k_stride,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
# Patches for accelerate functionality
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(context_parallel_degree=self.context_parallel_degree)
|
||||
|
||||
def _register_model_hooks(self):
|
||||
# Forward pre-hook to apply context parallelism
|
||||
def cp_flash_pre_hook(_, args, kwargs):
|
||||
# Get parameter names from the model's forward function
|
||||
forward_params = list(
|
||||
inspect.signature(self.models[0].forward).parameters.keys()
|
||||
)
|
||||
|
||||
updated_kwargs = kwargs.copy()
|
||||
for i, arg in enumerate(args):
|
||||
if i < len(forward_params):
|
||||
updated_kwargs[forward_params[i]] = arg
|
||||
|
||||
# Any excess positional arguments are kept as-is
|
||||
remaining_args = args[len(forward_params) :]
|
||||
|
||||
# Apply context parallelism to updated kwargs
|
||||
updated_kwargs, self.original_seq_len, self.pad_len = (
|
||||
self.apply_context_parallelism(updated_kwargs)
|
||||
)
|
||||
|
||||
return remaining_args, updated_kwargs
|
||||
|
||||
# Forward post-hook to gather outputs
|
||||
def cp_flash_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
||||
# Gather the sharded outputs
|
||||
output = self._gather_outputs(output)
|
||||
|
||||
# Remove padding if it was added
|
||||
if self.pad_len > 0:
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
if value.size(1) == self.original_seq_len + self.pad_len:
|
||||
# Slice to remove padding
|
||||
output[key] = value[:, : self.original_seq_len].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
def make_sdpa_pre_hook(model_idx: int) -> Callable:
|
||||
def cp_sdpa_pre_hook(_, args, kwargs):
|
||||
# Get parameter names from the model's forward function
|
||||
forward_params = list(
|
||||
inspect.signature(self.models[0].forward).parameters.keys()
|
||||
)
|
||||
|
||||
updated_kwargs = kwargs.copy()
|
||||
for i, arg in enumerate(args):
|
||||
if i < len(forward_params):
|
||||
updated_kwargs[forward_params[i]] = arg
|
||||
|
||||
# Any excess positional arguments are kept as-is
|
||||
remaining_args = args[len(forward_params) :]
|
||||
|
||||
to_shard = {k: v for k, v in updated_kwargs.items() if v.ndim > 1}
|
||||
|
||||
with self.context_parallel_managers[model_idx](list(to_shard.values())):
|
||||
return remaining_args, updated_kwargs
|
||||
|
||||
return cp_sdpa_pre_hook
|
||||
|
||||
# Register both hooks
|
||||
for i, model in enumerate(self.models):
|
||||
if self.backend == "flash_attention":
|
||||
self.hook_handles.append(
|
||||
model.register_forward_pre_hook(cp_flash_pre_hook, with_kwargs=True)
|
||||
)
|
||||
self.hook_handles.append(
|
||||
model.register_forward_hook(cp_flash_post_hook)
|
||||
)
|
||||
else:
|
||||
self.hook_handles.append(
|
||||
model.register_forward_pre_hook(
|
||||
make_sdpa_pre_hook(i), with_kwargs=True
|
||||
)
|
||||
)
|
||||
|
||||
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
||||
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
output[key] = AllGatherWithGrad.apply(value, self.process_group)
|
||||
|
||||
return output
|
||||
@@ -1,15 +1,28 @@
|
||||
"""Utils for context parallel context manager."""
|
||||
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from axolotl.monkeypatch.ring_attn.patch import update_ring_attn_params
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
register_ring_attn,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
|
||||
# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this
|
||||
# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.
|
||||
def apply_context_parallelism(
|
||||
def apply_sequence_parallelism(
|
||||
batch: dict[str, torch.Tensor],
|
||||
local_rank: int,
|
||||
local_world_size: int,
|
||||
@@ -17,15 +30,15 @@ def apply_context_parallelism(
|
||||
ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument
|
||||
) -> tuple[dict[str, torch.Tensor], int, int]:
|
||||
"""
|
||||
Apply context parallelism slicing to a batch.
|
||||
Apply sequence parallelism slicing to a batch.
|
||||
|
||||
Special handling is implemented for integer logits_to_keep, which indicates
|
||||
to only keep the last N tokens in the input sequence during generation.
|
||||
to only keep the last N tokens in the sequence during generation.
|
||||
|
||||
Args:
|
||||
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
|
||||
local_rank: Local rank in the context parallel group.
|
||||
local_world_size: World size of the context parallel group.
|
||||
local_rank: Local rank in the sequence parallel group.
|
||||
local_world_size: World size of the sequence parallel group.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused, but
|
||||
related to above TODO.
|
||||
@@ -120,7 +133,7 @@ def apply_context_parallelism(
|
||||
# Update the total sequence length after padding
|
||||
total_seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# Slice batch for context parallel
|
||||
# Slice batch for sequence parallel
|
||||
for key in batch:
|
||||
if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:
|
||||
continue
|
||||
@@ -146,6 +159,144 @@ def apply_context_parallelism(
|
||||
return batch, original_seq_len, pad_len
|
||||
|
||||
|
||||
class SequenceParallelContextManager:
|
||||
"""Context manager for sequence parallelism operations.
|
||||
|
||||
This class provides a context that will automatically apply sequence parallelism
|
||||
during model forward passes using a pre-forward hook, and gather outputs from
|
||||
across the sequence parallelism group using a post-forward hook.
|
||||
|
||||
Args:
|
||||
models: List of models to apply sequence parallelism to pre- and post- forward
|
||||
hooks.
|
||||
sequence_parallel_degree: Number of processes to split sequences over.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
models: list[nn.Module],
|
||||
sequence_parallel_degree: int,
|
||||
gradient_accumulation_steps: int,
|
||||
ring_attn_func: RingAttnFunc,
|
||||
heads_k_stride: int | None,
|
||||
):
|
||||
self.models = models
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.heads_k_stride = heads_k_stride
|
||||
self._register_ring_attn()
|
||||
|
||||
# Set distributed info for local rank
|
||||
self.process_group = get_ring_attn_group()
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
|
||||
# Will store hook handles for removal
|
||||
self.hook_handles: list[RemovableHandle] = []
|
||||
|
||||
# Store original sequence length and padding information
|
||||
self.original_seq_len = 0
|
||||
self.pad_len = 0
|
||||
|
||||
# Create a partially applied version of the apply_sequence_parallelism function
|
||||
self.apply_sequence_parallelism = functools.partial(
|
||||
apply_sequence_parallelism,
|
||||
local_rank=self.local_rank,
|
||||
local_world_size=self.local_world_size,
|
||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
self._register_model_hooks()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Remove all hooks
|
||||
for handle in self.hook_handles:
|
||||
handle.remove()
|
||||
self.hook_handles = []
|
||||
|
||||
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
|
||||
|
||||
def _register_ring_attn(self):
|
||||
# Initialize ring attn for sequence parallelism
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=self.sequence_parallel_degree,
|
||||
heads_k_stride=self.heads_k_stride,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
# Patches for accelerate functionality
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(
|
||||
sequence_parallel_degree=self.sequence_parallel_degree
|
||||
)
|
||||
|
||||
def _register_model_hooks(self):
|
||||
# Forward pre-hook to apply sequence parallelism
|
||||
def sequence_parallel_pre_hook(_, args, kwargs):
|
||||
# Get parameter names from the model's forward function
|
||||
forward_params = list(
|
||||
inspect.signature(self.models[0].forward).parameters.keys()
|
||||
)
|
||||
|
||||
updated_kwargs = kwargs.copy()
|
||||
for i, arg in enumerate(args):
|
||||
if i < len(forward_params):
|
||||
updated_kwargs[forward_params[i]] = arg
|
||||
|
||||
# Any excess positional arguments are kept as-is
|
||||
remaining_args = args[len(forward_params) :]
|
||||
|
||||
# Apply sequence parallelism to updated kwargs
|
||||
updated_kwargs, self.original_seq_len, self.pad_len = (
|
||||
self.apply_sequence_parallelism(updated_kwargs)
|
||||
)
|
||||
|
||||
return remaining_args, updated_kwargs
|
||||
|
||||
# Forward post-hook to gather outputs
|
||||
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
||||
# Gather the sharded outputs
|
||||
output = self._gather_outputs(output)
|
||||
|
||||
# Remove padding if it was added
|
||||
if self.pad_len > 0:
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
if value.size(1) == self.original_seq_len + self.pad_len:
|
||||
# Slice to remove padding
|
||||
output[key] = value[:, : self.original_seq_len].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
# Register both hooks
|
||||
for model in self.models:
|
||||
self.hook_handles.append(
|
||||
model.register_forward_pre_hook(
|
||||
sequence_parallel_pre_hook, with_kwargs=True
|
||||
)
|
||||
)
|
||||
self.hook_handles.append(
|
||||
model.register_forward_hook(sequence_parallel_post_hook)
|
||||
)
|
||||
|
||||
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
||||
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
output[key] = AllGatherWithGrad.apply(value, self.process_group)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class AllGatherWithGrad(torch.autograd.Function):
|
||||
"""Custom autograd function for all-gather to preserve gradients."""
|
||||
|
||||
@@ -262,7 +262,7 @@ class AxolotlInputConfig(
|
||||
|
||||
val_set_size: float | None = Field(default=0.0)
|
||||
|
||||
context_parallel_degree: int | None = None
|
||||
sequence_parallel_degree: int | None = None
|
||||
heads_k_stride: int | None = None
|
||||
ring_attn_func: RingAttnFunc | None = None
|
||||
|
||||
@@ -1179,39 +1179,24 @@ class AxolotlInputConfig(
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_grpo_liger_context_parallel(cls, data):
|
||||
def check_grpo_liger_sequence_parallel(cls, data):
|
||||
if (
|
||||
data.get("rl") == "grpo"
|
||||
and data.get("trl", {})
|
||||
and data.get("trl").get("use_liger_loss")
|
||||
and data.get("context_parallel_degree", 1) > 1
|
||||
and data.get("sequence_parallel_degree", 1) > 1
|
||||
):
|
||||
raise ValueError("GRPO + CP + Liger not currently supported")
|
||||
raise ValueError("GRPO + SP + Liger not currently supported")
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_context_parallel_degree(self):
|
||||
if not self.context_parallel_degree:
|
||||
self.context_parallel_degree = 1
|
||||
elif self.context_parallel_degree > 1:
|
||||
import torch
|
||||
|
||||
world_size = torch.cuda.device_count()
|
||||
if not world_size >= self.context_parallel_degree:
|
||||
def check_sequence_parallel_degree(self):
|
||||
if not self.sequence_parallel_degree:
|
||||
self.sequence_parallel_degree = 1
|
||||
elif self.sequence_parallel_degree > 1:
|
||||
if not self.flash_attention:
|
||||
raise ValueError(
|
||||
f"World size ({world_size}) must be greater "
|
||||
f"than or equal to CP degree ({self.context_parallel_degree})"
|
||||
)
|
||||
if not world_size % self.context_parallel_degree == 0:
|
||||
raise ValueError(
|
||||
f"SP degree ({self.context_parallel_degree}) "
|
||||
f"must evenly divide world size ({world_size})"
|
||||
)
|
||||
|
||||
if not (self.flash_attention or self.sdp_attention):
|
||||
raise ValueError(
|
||||
"flash_attention: true or sdp_attention: true "
|
||||
"must be set with context_parallel_degree > 1"
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
)
|
||||
|
||||
if self.sample_packing and self.micro_batch_size > 1:
|
||||
@@ -1220,22 +1205,21 @@ class AxolotlInputConfig(
|
||||
"due to a `ring-flash-attn` requirement"
|
||||
)
|
||||
|
||||
if self.flash_attention:
|
||||
try:
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"context_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
try:
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
|
||||
# TODO: monkeypatch / callback to average losses correctly across CP ranks
|
||||
# / fix gradient scaling across CP ranks. Losses, grads should be scaled
|
||||
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
||||
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
||||
# according to the proportion of non-padding tokens per rank.
|
||||
LOG.warning(
|
||||
"Context parallelism (SP) is enabled with "
|
||||
f"context_parallel_degree={self.context_parallel_degree}. "
|
||||
"Sequence parallelism (SP) is enabled with "
|
||||
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
|
||||
"Please note that logged losses may differ slightly to the non-SP "
|
||||
"losses due to transformers Trainer implementation details. "
|
||||
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||
@@ -1246,7 +1230,7 @@ class AxolotlInputConfig(
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_ring_attn_func(self):
|
||||
if getattr(self, "context_parallel_degree", 1) == 1:
|
||||
if getattr(self, "sequence_parallel_degree", 1) == 1:
|
||||
return self
|
||||
|
||||
if self.ring_attn_func is not None:
|
||||
@@ -1275,7 +1259,7 @@ class AxolotlInputConfig(
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
"""Wrapper to validate GPU capabilities with the config options"""
|
||||
|
||||
capabilities: GPUCapabilities
|
||||
env_capabilities: EnvCapabilities
|
||||
|
||||
@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
* cfg.context_parallel_degree
|
||||
* cfg.sequence_parallel_degree
|
||||
)
|
||||
LOG.debug(
|
||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
||||
@@ -479,7 +479,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
# on the agreed on value for sample_packing_eff_est
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len * cfg.num_epochs * cfg.context_parallel_degree
|
||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||
)
|
||||
)
|
||||
|
||||
@@ -502,7 +502,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
* cfg.context_parallel_degree
|
||||
* cfg.sequence_parallel_degree
|
||||
/ cfg.batch_size
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
shared pytest fixtures
|
||||
"""
|
||||
"""Shared pytest fixtures"""
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
@@ -559,3 +557,9 @@ def test_load_fixtures(
|
||||
download_llama2_model_fixture,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_telemetry(monkeypatch):
|
||||
monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1")
|
||||
yield
|
||||
|
||||
@@ -64,7 +64,7 @@ def fixture_base_cfg():
|
||||
"dataloader_num_workers": 1,
|
||||
"dataloader_pin_memory": True,
|
||||
"dataloader_prefetch_factor": 2,
|
||||
"context_parallel_degree": 1,
|
||||
"sequence_parallel_degree": 1,
|
||||
# Dtype
|
||||
"fp16": False,
|
||||
"bf16": False,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""E2E tests for context parallelism"""
|
||||
"""E2E tests for sequence parallelism"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@@ -12,10 +12,10 @@ from axolotl.utils.dict import DictDefault
|
||||
from ...utils import check_tensorboard
|
||||
|
||||
|
||||
class TestContextParallelism:
|
||||
"""Test case for training with context parallelism enabled"""
|
||||
class TestSequenceParallelism:
|
||||
"""Test case for training with sequence parallelism enabled"""
|
||||
|
||||
def _run_context_parallel_test(
|
||||
def _run_sequence_parallel_test(
|
||||
self,
|
||||
temp_dir,
|
||||
sample_packing=True,
|
||||
@@ -24,7 +24,7 @@ class TestContextParallelism:
|
||||
ring_attn_func=None,
|
||||
threshold=2.0,
|
||||
):
|
||||
"""Helper method to run context parallel tests with different configurations"""
|
||||
"""Helper method to run sequence parallel tests with different configurations"""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -66,7 +66,7 @@ class TestContextParallelism:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"context_parallel_degree": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"ring_attn_func": ring_attn_func,
|
||||
}
|
||||
)
|
||||
@@ -109,7 +109,7 @@ class TestContextParallelism:
|
||||
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
],
|
||||
)
|
||||
def test_context_parallel_training(
|
||||
def test_sequence_parallel_training(
|
||||
self,
|
||||
temp_dir,
|
||||
sample_packing,
|
||||
@@ -118,8 +118,8 @@ class TestContextParallelism:
|
||||
ring_attn_func,
|
||||
threshold,
|
||||
):
|
||||
"""Test context parallel training with different configurations"""
|
||||
self._run_context_parallel_test(
|
||||
"""Test sequence parallel training with different configurations"""
|
||||
self._run_sequence_parallel_test(
|
||||
temp_dir,
|
||||
sample_packing=sample_packing,
|
||||
micro_batch_size=micro_batch_size,
|
||||
|
||||
@@ -296,7 +296,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"context_parallel_degree": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Tests for context parallelism functionality."""
|
||||
"""Tests for sequence parallelism functionality."""
|
||||
|
||||
# pylint: disable=redefined-outer-name,unused-argument
|
||||
|
||||
@@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import (
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.context_parallel import apply_context_parallelism
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
@@ -54,8 +54,8 @@ def fixture_cfg():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def context_parallel_batch():
|
||||
"""Create a test batch for context parallelism tests."""
|
||||
def sequence_parallel_batch():
|
||||
"""Create a test batch for sequence parallelism tests."""
|
||||
batch_size = 1
|
||||
seq_len = 8
|
||||
|
||||
@@ -110,7 +110,7 @@ class TestRingAttention:
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(
|
||||
context_parallel_degree=4,
|
||||
sequence_parallel_degree=4,
|
||||
heads_k_stride=1,
|
||||
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
||||
)
|
||||
@@ -126,7 +126,7 @@ class TestRingAttention:
|
||||
|
||||
|
||||
class TestConfigValidation:
|
||||
"""Tests for validating context parallelism configurations."""
|
||||
"""Tests for validating sequence parallelism configurations."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_mocks(self, monkeypatch):
|
||||
@@ -155,24 +155,24 @@ class TestConfigValidation:
|
||||
[
|
||||
# Valid configuration
|
||||
(
|
||||
{"context_parallel_degree": 2, "flash_attention": True},
|
||||
{"context_parallel_degree": 2, "flash_attention": True},
|
||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||
True,
|
||||
None,
|
||||
),
|
||||
# Default context_parallel_degree
|
||||
({}, {"context_parallel_degree": 1}, True, None),
|
||||
# Invalid: context_parallel_degree > 1 without flash_attention
|
||||
# Default sequence_parallel_degree
|
||||
({}, {"sequence_parallel_degree": 1}, True, None),
|
||||
# Invalid: sequence_parallel_degree > 1 without flash_attention
|
||||
(
|
||||
{"context_parallel_degree": 2, "flash_attention": False},
|
||||
{"sequence_parallel_degree": 2, "flash_attention": False},
|
||||
None,
|
||||
False,
|
||||
"flash_attention: true must be set",
|
||||
),
|
||||
# Invalid: context_parallel_degree > 1 with sample_packing and micro_batch_size > 1
|
||||
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
|
||||
(
|
||||
{
|
||||
"context_parallel_degree": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"micro_batch_size": 2,
|
||||
@@ -185,32 +185,32 @@ class TestConfigValidation:
|
||||
# Valid: Basic GRPO config
|
||||
(
|
||||
{
|
||||
"context_parallel_degree": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
},
|
||||
{
|
||||
"context_parallel_degree": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": TRLConfig(use_liger_loss=True),
|
||||
},
|
||||
True,
|
||||
"GRPO + CP + Liger not currently supported",
|
||||
"GRPO + SP + Liger not currently supported",
|
||||
),
|
||||
# Invalid: GRPO config with Liger loss
|
||||
(
|
||||
{
|
||||
"rl": "grpo",
|
||||
"context_parallel_degree": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
},
|
||||
None,
|
||||
False,
|
||||
"GRPO + CP + Liger not currently supported",
|
||||
"GRPO + SP + Liger not currently supported",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
@@ -222,10 +222,10 @@ class TestConfigValidation:
|
||||
"grpo_with_liger_loss",
|
||||
],
|
||||
)
|
||||
def test_context_parallel_config_validation(
|
||||
def test_sequence_parallel_config_validation(
|
||||
self, base_cfg, config_updates, expected_values, should_pass, error_msg
|
||||
):
|
||||
"""Test various context parallelism configuration scenarios."""
|
||||
"""Test various sequence parallelism configuration scenarios."""
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
# Apply updates to base config
|
||||
@@ -261,7 +261,7 @@ class TestConfigValidation:
|
||||
|
||||
# Apply updates to base config
|
||||
cfg = base_cfg | {
|
||||
"context_parallel_degree": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": sample_packing,
|
||||
}
|
||||
@@ -281,7 +281,7 @@ class TestConfigValidation:
|
||||
|
||||
# Invalid configuration with invalid ring_attn_func
|
||||
cfg = base_cfg | {
|
||||
"context_parallel_degree": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"ring_attn_func": "INVALID_FUNC",
|
||||
}
|
||||
@@ -294,8 +294,8 @@ class TestConfigValidation:
|
||||
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
|
||||
|
||||
|
||||
class TestApplyContextParallelism:
|
||||
"""Tests for the apply_context_parallelism function."""
|
||||
class TestApplySequenceParallelism:
|
||||
"""Tests for the apply_sequence_parallelism function."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_distributed(self, monkeypatch):
|
||||
@@ -324,12 +324,12 @@ class TestApplyContextParallelism:
|
||||
)
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_world_size_one(self, mock_get_ring_attn_group, context_parallel_batch):
|
||||
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test that function returns original batch when world size is 1."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
result, _, _ = apply_context_parallelism(
|
||||
batch=context_parallel_batch,
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=sequence_parallel_batch,
|
||||
local_rank=0,
|
||||
local_world_size=1,
|
||||
gradient_accumulation_steps=1,
|
||||
@@ -337,17 +337,17 @@ class TestApplyContextParallelism:
|
||||
)
|
||||
|
||||
# Should return the original batch unchanged
|
||||
assert result == context_parallel_batch
|
||||
assert result == sequence_parallel_batch
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank0(self, mock_get_ring_attn_group, context_parallel_batch):
|
||||
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = context_parallel_batch
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
|
||||
result, _, _ = apply_context_parallelism(
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=batch,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
@@ -366,15 +366,15 @@ class TestApplyContextParallelism:
|
||||
)
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank1(self, mock_get_ring_attn_group, context_parallel_batch):
|
||||
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = context_parallel_batch
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
result, _, _ = apply_context_parallelism(
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=batch,
|
||||
local_rank=1,
|
||||
local_world_size=2,
|
||||
@@ -386,14 +386,14 @@ class TestApplyContextParallelism:
|
||||
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
|
||||
|
||||
# TODO(djsaunde): add back once implemented.
|
||||
# def test_batch_zigzag(self, context_parallel_batch):
|
||||
# def test_batch_zigzag(self, sequence_parallel_batch):
|
||||
# """Test BATCH_ZIGZAG sharding pattern."""
|
||||
# batch = context_parallel_batch
|
||||
# batch = sequence_parallel_batch
|
||||
# original_input_ids = batch["input_ids"].clone()
|
||||
# seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# # Test rank 0
|
||||
# result_rank0 = apply_context_parallelism(
|
||||
# result_rank0 = apply_sequence_parallelism(
|
||||
# batch={k: v.clone() for k, v in batch.items()},
|
||||
# local_rank=0,
|
||||
# local_world_size=2,
|
||||
@@ -401,7 +401,7 @@ class TestApplyContextParallelism:
|
||||
# )
|
||||
|
||||
# # Test rank 1
|
||||
# result_rank1 = apply_context_parallelism(
|
||||
# result_rank1 = apply_sequence_parallelism(
|
||||
# batch={k: v.clone() for k, v in batch.items()},
|
||||
# local_rank=1,
|
||||
# local_world_size=2,
|
||||
@@ -430,17 +430,17 @@ class TestApplyContextParallelism:
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_partial_application(
|
||||
self, mock_get_ring_attn_group, context_parallel_batch
|
||||
self, mock_get_ring_attn_group, sequence_parallel_batch
|
||||
):
|
||||
"""Test that we can create a partially applied version of the function."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = context_parallel_batch
|
||||
batch = sequence_parallel_batch
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
# Create a partially applied function
|
||||
rank0_ring_parallel = functools.partial(
|
||||
apply_context_parallelism,
|
||||
apply_sequence_parallelism,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
@@ -457,14 +457,16 @@ class TestApplyContextParallelism:
|
||||
original_input_ids[:, : original_input_ids.shape[1] // 2],
|
||||
)
|
||||
|
||||
def test_missing_position_ids(self, context_parallel_batch):
|
||||
def test_missing_position_ids(self, sequence_parallel_batch):
|
||||
"""Test handling of batch without position_ids."""
|
||||
# Create a batch without position_ids
|
||||
batch = {k: v for k, v in context_parallel_batch.items() if k != "position_ids"}
|
||||
batch = {
|
||||
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
|
||||
}
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
# This should run without error even though position_ids is missing
|
||||
result, _, _ = apply_context_parallelism(
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=batch,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
|
||||
0
tests/telemetry/__init__.py
Normal file
0
tests/telemetry/__init__.py
Normal file
9
tests/telemetry/conftest.py
Normal file
9
tests/telemetry/conftest.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Shared pytest fixtures for telemetry tests."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_telemetry(monkeypatch):
|
||||
monkeypatch.delenv("AXOLOTL_DO_NOT_TRACK", raising=False)
|
||||
yield
|
||||
373
tests/telemetry/test_callbacks.py
Normal file
373
tests/telemetry/test_callbacks.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""Tests for telemetry callback module."""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||
|
||||
from axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback
|
||||
|
||||
|
||||
def calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0):
|
||||
"""Calculate expected metrics values for tests"""
|
||||
time_diff = current_time - last_time
|
||||
step_diff = step - last_step
|
||||
return {
|
||||
"steps_per_second": (
|
||||
step_diff / time_diff if time_diff > 0 and step_diff > 0 else 0
|
||||
),
|
||||
"time_since_last_report": time_diff,
|
||||
"elapsed_time": current_time - start_time,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_time():
|
||||
"""Mock time.time() to have predictable values in tests"""
|
||||
with patch("axolotl.telemetry.callbacks.time") as mock_time:
|
||||
mock_time.time.return_value = 1000.0
|
||||
yield mock_time
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_telemetry_manager():
|
||||
"""Create a mock TelemetryManager"""
|
||||
with patch("axolotl.telemetry.callbacks.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime_metrics_tracker():
|
||||
"""Create a mock RuntimeMetricsTracker"""
|
||||
with patch(
|
||||
"axolotl.telemetry.callbacks.RuntimeMetricsTracker"
|
||||
) as mock_tracker_class:
|
||||
mock_tracker = MagicMock()
|
||||
# Set up metrics property on the tracker
|
||||
mock_metrics = MagicMock()
|
||||
mock_metrics.to_dict.return_value = {
|
||||
"total_steps": 100,
|
||||
"peak_cpu_memory_bytes": 1024,
|
||||
}
|
||||
mock_tracker.metrics = mock_metrics
|
||||
|
||||
# Make the constructor return our mock
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
yield mock_tracker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def training_args():
|
||||
"""Create a minimal TrainingArguments instance"""
|
||||
return TrainingArguments(output_dir="./output")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trainer_state():
|
||||
"""Create a mock TrainerState"""
|
||||
state = MagicMock(spec=TrainerState)
|
||||
state.global_step = 10
|
||||
state.epoch = 0.5 # halfway through first epoch
|
||||
state.log_history = [{"loss": 2.5, "learning_rate": 5e-5}]
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trainer_control():
|
||||
"""Create a mock TrainerControl"""
|
||||
return MagicMock(spec=TrainerControl)
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@pytest.fixture
|
||||
def callback(mock_telemetry_manager, mock_runtime_metrics_tracker):
|
||||
"""Create a TelemetryCallback instance with mocked dependencies"""
|
||||
return TelemetryCallback()
|
||||
|
||||
|
||||
class TestTelemetryCallback:
|
||||
"""Tests for the TelemetryCallback class."""
|
||||
|
||||
def test_initialization(self, callback, mock_runtime_metrics_tracker):
|
||||
"""Test callback initialization."""
|
||||
assert callback.current_epoch == -1
|
||||
assert callback.tracker == mock_runtime_metrics_tracker
|
||||
assert callback.last_report_step == 0
|
||||
assert hasattr(callback, "start_time")
|
||||
assert hasattr(callback, "last_report_time")
|
||||
assert callback.report_interval_steps == 100
|
||||
|
||||
def test_on_train_begin(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_train_begin sends expected event."""
|
||||
callback.on_train_begin(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_telemetry_manager.send_event.assert_called_once_with(
|
||||
event_type="train-start"
|
||||
)
|
||||
|
||||
def test_on_train_end(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_train_end sends expected event with metrics."""
|
||||
callback.on_train_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
|
||||
assert call_args["event_type"] == "train-end"
|
||||
assert "loss" in call_args["properties"]
|
||||
assert call_args["properties"]["loss"] == 2.5
|
||||
assert "learning_rate" in call_args["properties"]
|
||||
assert call_args["properties"]["learning_rate"] == 5e-5
|
||||
|
||||
# Check that metrics from RuntimeMetricsTracker are included
|
||||
assert "total_steps" in call_args["properties"]
|
||||
assert call_args["properties"]["total_steps"] == 100
|
||||
assert "peak_cpu_memory_bytes" in call_args["properties"]
|
||||
assert call_args["properties"]["peak_cpu_memory_bytes"] == 1024
|
||||
|
||||
def test_on_epoch_begin(
|
||||
self,
|
||||
callback,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_epoch_begin updates epoch counter and calls tracker."""
|
||||
initial_epoch = callback.current_epoch
|
||||
|
||||
callback.on_epoch_begin(training_args, trainer_state, trainer_control)
|
||||
|
||||
assert callback.current_epoch == initial_epoch + 1
|
||||
mock_runtime_metrics_tracker.start_epoch.assert_called_once_with(
|
||||
initial_epoch + 1
|
||||
)
|
||||
|
||||
def test_on_epoch_end(
|
||||
self,
|
||||
callback,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_epoch_end calls tracker."""
|
||||
# Set current epoch
|
||||
callback.current_epoch = 2
|
||||
|
||||
callback.on_epoch_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2)
|
||||
|
||||
def test_on_step_end_no_report(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end updates tracker but doesn't report if criteria not met."""
|
||||
# Set up state to avoid reporting
|
||||
trainer_state.global_step = 42 # Not divisible by report_interval_steps
|
||||
callback.last_report_step = 41 # Just 1 step since last report
|
||||
callback.last_report_time = time.time() # Just now
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should update tracker
|
||||
mock_runtime_metrics_tracker.update_step.assert_called_once_with(42)
|
||||
|
||||
# Should not send telemetry
|
||||
mock_telemetry_manager.send_event.assert_not_called()
|
||||
|
||||
# Should not update last report time/step
|
||||
assert callback.last_report_step == 41
|
||||
|
||||
def test_on_step_end_report_interval_steps(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker,
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end reports when step interval is reached."""
|
||||
# Set up state with clear values
|
||||
current_step = 100 # Exactly matches report_interval_steps
|
||||
last_step = 0
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
time_diff = current_time - start_time # 100 seconds
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = start_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should update tracker
|
||||
mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step)
|
||||
mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once()
|
||||
|
||||
# Should send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert call_args["event_type"] == "train-progress"
|
||||
|
||||
# Properties should include expected values
|
||||
props = call_args["properties"]
|
||||
assert props["step"] == current_step
|
||||
assert props["elapsed_time"] == time_diff # 1000 - 900 = 100
|
||||
assert props["time_since_last_report"] == time_diff # 1000 - 900 = 100
|
||||
assert props["steps_per_second"] == 1.0 # 100 steps / 100 seconds
|
||||
|
||||
# Should update last report time/step
|
||||
assert callback.last_report_step == current_step
|
||||
assert callback.last_report_time == current_time
|
||||
|
||||
def test_on_step_end_report_time_elapsed(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end reports when enough time has elapsed."""
|
||||
# Set up state with clear values
|
||||
current_step = 120
|
||||
last_step = 10
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
time_diff = TIME_SINCE_LAST + 1 # Just over the threshold
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = current_time - time_diff
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should include expected values
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
expected_metrics = calc_expected_metrics(
|
||||
current_step, last_step, current_time, current_time - time_diff, start_time
|
||||
)
|
||||
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
|
||||
assert (
|
||||
props["time_since_last_report"]
|
||||
== expected_metrics["time_since_last_report"]
|
||||
)
|
||||
|
||||
def test_on_step_end_first_step(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end always reports on first step."""
|
||||
# Set up state with clear values
|
||||
current_step = 1 # First step
|
||||
last_step = 0
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
last_report_time = 999.0 # Just 1 second ago
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = last_report_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should send telemetry even though not much time has passed
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should include expected values for first step
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
assert props["step"] == current_step
|
||||
expected_metrics = calc_expected_metrics(
|
||||
current_step, last_step, current_time, last_report_time, start_time
|
||||
)
|
||||
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
|
||||
|
||||
def test_log_history_empty(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test handling of empty log history."""
|
||||
# Set up state with clear values
|
||||
current_step = 1
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
trainer_state.log_history = []
|
||||
callback.start_time = start_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should still send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should have default values for missing log data
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
assert props["loss"] == 0
|
||||
assert props["learning_rate"] == 0
|
||||
341
tests/telemetry/test_errors.py
Normal file
341
tests/telemetry/test_errors.py
Normal file
@@ -0,0 +1,341 @@
|
||||
"""Tests for telemetry error utilities"""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.telemetry.errors import sanitize_stack_trace, send_errors
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_error_flag(monkeypatch):
|
||||
"""Reset ERROR_HANDLED flag using monkeypatch"""
|
||||
import axolotl.telemetry.errors
|
||||
|
||||
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
|
||||
yield
|
||||
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_stack_trace():
|
||||
"""Provide a sample stack trace with mixed paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer
|
||||
model = get_model(cfg, tokenizer)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model
|
||||
raise ValueError("Model path not found")
|
||||
ValueError: Model path not found
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def windows_stack_trace():
|
||||
"""Provide a sample stack trace with Windows paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer
|
||||
model = get_model(cfg, tokenizer)
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained
|
||||
raise ValueError(f"Unrecognized configuration class {config.__class__}")
|
||||
ValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixed_stack_trace():
|
||||
"""Provide a sample stack trace with both axolotl and non-axolotl paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train
|
||||
self._inner_training_loop()
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop
|
||||
super()._inner_training_loop()
|
||||
File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
|
||||
data = self._next_data()
|
||||
RuntimeError: CUDA out of memory
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def venv_stack_trace():
|
||||
"""Provide a sample stack trace with virtual environment paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train
|
||||
self._inner_training_loop()
|
||||
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop
|
||||
self.accelerator.backward(loss)
|
||||
File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward
|
||||
self.scaler.scale(loss).backward(**kwargs)
|
||||
File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
|
||||
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
|
||||
RuntimeError: CUDA out of memory
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dist_packages_stack_trace():
|
||||
"""Provide a sample stack trace with dist-packages paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
|
||||
data = self._next_data()
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data
|
||||
data = self._dataset_fetcher.fetch(index)
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
|
||||
data = [self.dataset[idx] for idx in possibly_batched_index]
|
||||
File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__
|
||||
raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.")
|
||||
IndexError: Index 10000 out of range for dataset of length 9832.
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project_stack_trace():
|
||||
"""Provide a sample stack trace from a project directory (not a virtual env)"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/projects/myproject/run.py", line 25, in <module>
|
||||
main()
|
||||
File "/home/user/projects/myproject/src/cli.py", line 45, in main
|
||||
app.run()
|
||||
File "/home/user/projects/myproject/src/app.py", line 102, in run
|
||||
raise ValueError("Configuration missing")
|
||||
ValueError: Configuration missing
|
||||
"""
|
||||
|
||||
|
||||
def test_sanitize_stack_trace(example_stack_trace):
|
||||
"""Test that sanitize_stack_trace properly preserves axolotl paths"""
|
||||
sanitized = sanitize_stack_trace(example_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user" not in sanitized
|
||||
assert ".local/lib/python3.9" not in sanitized
|
||||
|
||||
# Check that site-packages is preserved
|
||||
assert "site-packages/axolotl/cli/train.py" in sanitized
|
||||
assert "site-packages/axolotl/train.py" in sanitized
|
||||
assert "site-packages/axolotl/utils/models.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Model path not found" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_windows_paths(windows_stack_trace):
|
||||
"""Test that sanitize_stack_trace handles Windows paths"""
|
||||
sanitized = sanitize_stack_trace(windows_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "C:\\Users\\name" not in sanitized
|
||||
assert "AppData\\Local\\Programs\\Python" not in sanitized
|
||||
|
||||
# Check that both axolotl and transformers packages are preserved
|
||||
assert (
|
||||
"site-packages\\axolotl\\cli\\train.py" in sanitized
|
||||
or "site-packages/axolotl/cli/train.py" in sanitized
|
||||
)
|
||||
assert (
|
||||
"site-packages\\axolotl\\train.py" in sanitized
|
||||
or "site-packages/axolotl/train.py" in sanitized
|
||||
)
|
||||
assert (
|
||||
"site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized
|
||||
or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized
|
||||
)
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Unrecognized configuration class" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_mixed_paths(mixed_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves all package paths"""
|
||||
sanitized = sanitize_stack_trace(mixed_stack_trace)
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "site-packages/axolotl/cli/train.py" in sanitized
|
||||
assert "site-packages/transformers/trainer.py" in sanitized
|
||||
assert "site-packages/axolotl/utils/trainer.py" in sanitized
|
||||
assert "site-packages/torch/utils/data/dataloader.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "RuntimeError: CUDA out of memory" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_venv_paths(venv_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves virtual environment package paths"""
|
||||
sanitized = sanitize_stack_trace(venv_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user/venv" not in sanitized
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "site-packages/transformers/trainer.py" in sanitized
|
||||
assert "site-packages/accelerate/accelerator.py" in sanitized
|
||||
assert "site-packages/torch/_tensor.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "RuntimeError: CUDA out of memory" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_dist_packages(dist_packages_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves dist-packages paths"""
|
||||
sanitized = sanitize_stack_trace(dist_packages_stack_trace)
|
||||
|
||||
# Check that system paths are removed
|
||||
assert "/usr/local/lib/python3.8" not in sanitized
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "dist-packages/torch/utils/data/dataloader.py" in sanitized
|
||||
assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized
|
||||
assert "dist-packages/datasets/arrow_dataset.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert (
|
||||
"IndexError: Index 10000 out of range for dataset of length 9832." in sanitized
|
||||
)
|
||||
|
||||
|
||||
def test_sanitize_project_paths(project_stack_trace):
|
||||
"""Test handling of project paths (non-virtual env)"""
|
||||
sanitized = sanitize_stack_trace(project_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user/projects" not in sanitized
|
||||
|
||||
# For non-package paths, we should at least preserve the filename
|
||||
assert "run.py" in sanitized
|
||||
assert "cli.py" in sanitized
|
||||
assert "app.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Configuration missing" in sanitized
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_telemetry_manager():
|
||||
"""Create a mock TelemetryManager"""
|
||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = True
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
|
||||
def test_send_errors_successful_execution(mock_telemetry_manager):
|
||||
"""Test that send_errors doesn't send telemetry for successful function execution"""
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
result = test_func()
|
||||
assert result == "success"
|
||||
mock_telemetry_manager.send_event.assert_not_called()
|
||||
|
||||
|
||||
def test_send_errors_with_exception(mock_telemetry_manager):
|
||||
"""Test that send_errors sends telemetry when an exception occurs"""
|
||||
test_error = ValueError("Test error")
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise test_error
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
test_func()
|
||||
|
||||
assert excinfo.value == test_error
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Check that the error info was passed correctly
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert "test_func-error" in call_args["event_type"]
|
||||
assert "Test error" in call_args["properties"]["exception"]
|
||||
assert "stack_trace" in call_args["properties"]
|
||||
|
||||
|
||||
def test_send_errors_nested_calls(mock_telemetry_manager):
|
||||
"""Test that send_errors only sends telemetry once for nested decorated functions"""
|
||||
|
||||
@send_errors
|
||||
def inner_func():
|
||||
raise ValueError("Inner error")
|
||||
|
||||
@send_errors
|
||||
def outer_func():
|
||||
return inner_func()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
outer_func()
|
||||
|
||||
# Telemetry should be sent only once for the inner function
|
||||
assert mock_telemetry_manager.send_event.call_count == 1
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert "inner_func-error" in call_args["event_type"]
|
||||
|
||||
|
||||
def test_send_errors_telemetry_disable():
|
||||
"""Test that send_errors doesn't attempt to send telemetry when disabled"""
|
||||
|
||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = False
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
mock_manager.send_event.assert_not_called()
|
||||
|
||||
|
||||
def test_error_handled_reset():
|
||||
"""Test that ERROR_HANDLED flag is properly reset"""
|
||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
# Create and configure the mock manager
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = True
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
|
||||
from axolotl.telemetry.errors import ERROR_HANDLED
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
assert not ERROR_HANDLED
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
from axolotl.telemetry.errors import ERROR_HANDLED
|
||||
|
||||
assert ERROR_HANDLED
|
||||
|
||||
|
||||
def test_module_path_resolution(mock_telemetry_manager):
|
||||
"""Test that the module path is correctly resolved for the event type"""
|
||||
import inspect
|
||||
|
||||
current_module = inspect.getmodule(test_module_path_resolution).__name__
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
assert mock_telemetry_manager.send_event.called
|
||||
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
|
||||
|
||||
expected_event_type = f"{current_module}.test_func-error"
|
||||
assert expected_event_type == event_type
|
||||
278
tests/telemetry/test_manager.py
Normal file
278
tests/telemetry/test_manager.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Tests for TelemetryManager class and utilities"""
|
||||
|
||||
# pylint: disable=redefined-outer-name,protected-access
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_whitelist(tmp_path):
|
||||
"""Create a temporary whitelist file for testing"""
|
||||
whitelist_content = {
|
||||
"organizations": ["meta-llama", "mistralai"],
|
||||
}
|
||||
whitelist_file = tmp_path / "whitelist.yaml"
|
||||
with open(whitelist_file, "w", encoding="utf-8") as f:
|
||||
yaml.dump(whitelist_content, f)
|
||||
|
||||
return str(whitelist_file)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def telemetry_manager_class():
|
||||
"""Reset the TelemetryManager singleton between tests"""
|
||||
original_instance = TelemetryManager._instance
|
||||
original_initialized = TelemetryManager._initialized
|
||||
TelemetryManager._instance = None
|
||||
TelemetryManager._initialized = False
|
||||
yield TelemetryManager
|
||||
TelemetryManager._instance = original_instance
|
||||
TelemetryManager._initialized = original_initialized
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(telemetry_manager_class, mock_whitelist):
|
||||
"""Create a TelemetryManager instance with mocked dependencies"""
|
||||
with (
|
||||
patch("posthog.capture"),
|
||||
patch("posthog.flush"),
|
||||
patch("time.sleep"),
|
||||
patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
|
||||
patch.dict(os.environ, {"RANK": "0"}),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
# Manually enable for most tests
|
||||
manager.enabled = True
|
||||
return manager
|
||||
|
||||
|
||||
def test_singleton_instance(telemetry_manager_class):
|
||||
"""Test that TelemetryManager is a singleton"""
|
||||
with (
|
||||
patch("posthog.capture"),
|
||||
patch("time.sleep"),
|
||||
patch.dict(os.environ, {"RANK": "0"}),
|
||||
):
|
||||
first = telemetry_manager_class()
|
||||
second = telemetry_manager_class()
|
||||
assert first is second
|
||||
assert telemetry_manager_class.get_instance() is first
|
||||
|
||||
|
||||
def test_telemetry_disabled_by_default(telemetry_manager_class):
|
||||
"""Test that telemetry is disabled by default (opt-in)"""
|
||||
with (
|
||||
patch.dict(os.environ, {"RANK": "0"}, clear=True),
|
||||
patch("time.sleep"),
|
||||
patch("logging.Logger.info"),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):
|
||||
"""Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0"""
|
||||
with (
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}),
|
||||
patch("time.sleep"),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
assert manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
|
||||
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
|
||||
with (
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}),
|
||||
patch("time.sleep"),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
|
||||
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"}
|
||||
),
|
||||
patch("time.sleep"),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
|
||||
"""Test that telemetry is disabled for non-main processes"""
|
||||
with (
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}),
|
||||
patch("time.sleep"),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
def test_opt_in_info_displayed(telemetry_manager_class):
|
||||
"""Test that opt-in info is displayed when telemetry is not configured"""
|
||||
with (
|
||||
patch.dict(os.environ, {"RANK": "0"}, clear=True),
|
||||
patch("logging.Logger.warning") as mock_warning,
|
||||
patch("time.sleep"),
|
||||
):
|
||||
telemetry_manager_class()
|
||||
info_displayed = False
|
||||
for call in mock_warning.call_args_list:
|
||||
print(f"call: {call}")
|
||||
if "Telemetry is currently disabled by default" in str(call):
|
||||
info_displayed = True
|
||||
break
|
||||
assert info_displayed
|
||||
|
||||
|
||||
def test_is_whitelisted(telemetry_manager_class, mock_whitelist):
|
||||
"""Test org whitelist functionality"""
|
||||
with (
|
||||
patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
|
||||
# Should match organizations from the mock whitelist
|
||||
assert manager._is_whitelisted("meta-llama/llama-7b")
|
||||
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
|
||||
# Should not match
|
||||
assert not manager._is_whitelisted("unknown/model")
|
||||
# Should handle case insensitively
|
||||
assert manager._is_whitelisted("META-LLAMA/Llama-7B")
|
||||
# Should handle empty input
|
||||
assert not manager._is_whitelisted("")
|
||||
|
||||
|
||||
def test_system_info_collection(manager):
|
||||
"""Test system information collection"""
|
||||
system_info = manager._get_system_info()
|
||||
|
||||
# Check essential keys
|
||||
assert "os" in system_info
|
||||
assert "python_version" in system_info
|
||||
assert "cpu_count" in system_info
|
||||
assert "memory_total" in system_info
|
||||
assert "accelerator_count" in system_info
|
||||
|
||||
|
||||
def test_send_event(telemetry_manager_class):
|
||||
"""Test basic event sending"""
|
||||
with (
|
||||
patch("posthog.capture") as mock_capture,
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
|
||||
# Test with clean properties (no PII)
|
||||
manager.send_event("test_event", {"key": "value"})
|
||||
assert mock_capture.called
|
||||
assert mock_capture.call_args[1]["event"] == "test_event"
|
||||
assert mock_capture.call_args[1]["properties"] == {"key": "value"}
|
||||
assert mock_capture.call_args[1]["distinct_id"] == manager.run_id
|
||||
|
||||
# Test with default properties (None)
|
||||
mock_capture.reset_mock()
|
||||
manager.send_event("simple_event")
|
||||
assert mock_capture.called
|
||||
assert mock_capture.call_args[1]["properties"] == {}
|
||||
|
||||
|
||||
def test_send_system_info(telemetry_manager_class):
|
||||
"""Test sending system info"""
|
||||
with (
|
||||
patch("posthog.capture") as mock_capture,
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
manager.send_system_info()
|
||||
assert mock_capture.called
|
||||
assert mock_capture.call_args[1]["event"] == "system-info"
|
||||
assert mock_capture.call_args[1]["properties"] == manager.system_info
|
||||
|
||||
|
||||
def test_redacted_properties(telemetry_manager_class):
|
||||
"""Test path redaction in send_event method"""
|
||||
with (
|
||||
patch("posthog.capture") as mock_capture,
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
# Test with properties containing various paths and non-paths
|
||||
test_properties = {
|
||||
"filepath": "/home/user/sensitive/data.txt",
|
||||
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
|
||||
"output_dir": "/var/lib/data",
|
||||
"path_to_model": "models/llama/7b",
|
||||
"message": "Training started", # Should not be redacted
|
||||
"metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted
|
||||
"base_model": "models/local_model",
|
||||
"nested": {
|
||||
"model_path": "/models/my_model",
|
||||
"root_dir": "/home/user/projects",
|
||||
"stats": {"steps": 1000, "epochs": 3}, # Should not be redacted
|
||||
},
|
||||
}
|
||||
|
||||
manager.send_event("test_event", test_properties)
|
||||
|
||||
# Verify the call was made
|
||||
assert mock_capture.called
|
||||
|
||||
# Get the sanitized properties that were sent
|
||||
sanitized = mock_capture.call_args[1]["properties"]
|
||||
|
||||
# Check that path-like and base_model keys were redacted
|
||||
assert sanitized["filepath"] == "[REDACTED]"
|
||||
assert sanitized["windows_path"] == "[REDACTED]"
|
||||
assert sanitized["path_to_model"] == "[REDACTED]"
|
||||
assert sanitized["base_model"] == "[REDACTED]"
|
||||
|
||||
# Check that non-path values were preserved
|
||||
assert sanitized["message"] == "Training started"
|
||||
assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95}
|
||||
|
||||
# Check nested structure handling
|
||||
assert sanitized["nested"]["model_path"] == "[REDACTED]"
|
||||
assert sanitized["nested"]["root_dir"] == "[REDACTED]"
|
||||
assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3}
|
||||
|
||||
|
||||
def test_disable_telemetry(manager):
|
||||
"""Test that disabled telemetry doesn't send events"""
|
||||
with patch("posthog.capture") as mock_capture:
|
||||
manager.enabled = False
|
||||
manager.send_event("test_event")
|
||||
assert not mock_capture.called
|
||||
|
||||
|
||||
def test_exception_handling_during_send(manager):
|
||||
"""Test that exceptions in PostHog are handled gracefully"""
|
||||
with (
|
||||
patch("posthog.capture", side_effect=Exception("Test error")),
|
||||
patch("logging.Logger.warning") as mock_warning,
|
||||
):
|
||||
manager.send_event("test_event")
|
||||
warning_logged = False
|
||||
for call in mock_warning.call_args_list:
|
||||
if "Failed to send telemetry event" in str(call):
|
||||
warning_logged = True
|
||||
break
|
||||
assert warning_logged
|
||||
|
||||
|
||||
def test_shutdown(manager):
|
||||
"""Test shutdown behavior"""
|
||||
with patch("posthog.flush") as mock_flush:
|
||||
manager.shutdown()
|
||||
assert mock_flush.called
|
||||
357
tests/telemetry/test_runtime_metrics.py
Normal file
357
tests/telemetry/test_runtime_metrics.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""Tests for runtime metrics telemetry module"""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_time():
|
||||
"""Mock time.time() to have predictable values in tests"""
|
||||
with patch("time.time") as mock_time:
|
||||
# Start with time 1000.0 and increment by 10 seconds on each call
|
||||
times = [1000.0 + i * 10 for i in range(10)]
|
||||
mock_time.side_effect = times
|
||||
yield mock_time
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_telemetry_manager():
|
||||
"""Create a mock TelemetryManager"""
|
||||
with patch(
|
||||
"axolotl.telemetry.runtime_metrics.TelemetryManager"
|
||||
) as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = True
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_psutil():
|
||||
"""Mock psutil for memory information"""
|
||||
with patch("axolotl.telemetry.runtime_metrics.psutil") as mock_psutil:
|
||||
mock_process = MagicMock()
|
||||
mock_memory_info = MagicMock()
|
||||
# Set initial memory to 1GB
|
||||
mock_memory_info.rss = 1024 * 1024 * 1024
|
||||
mock_process.memory_info.return_value = mock_memory_info
|
||||
mock_psutil.Process.return_value = mock_process
|
||||
yield mock_psutil
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_torch():
|
||||
"""Mock torch.cuda functions"""
|
||||
with patch("axolotl.telemetry.runtime_metrics.torch") as mock_torch:
|
||||
mock_torch.cuda.is_available.return_value = True
|
||||
mock_torch.cuda.device_count.return_value = 2
|
||||
|
||||
# Mock memory allocated per device (1GB for device 0, 2GB for device 1)
|
||||
mock_torch.cuda.memory_allocated.side_effect = (
|
||||
lambda device: (device + 1) * 1024 * 1024 * 1024
|
||||
)
|
||||
|
||||
yield mock_torch
|
||||
|
||||
|
||||
class TestRuntimeMetrics:
|
||||
"""Tests for RuntimeMetrics class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test RuntimeMetrics initialization."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
|
||||
assert metrics.start_time == 1000.0
|
||||
assert metrics.epoch_start_times == {}
|
||||
assert metrics.epoch_end_times == {}
|
||||
assert metrics.peak_gpu_memory == {}
|
||||
assert metrics.total_steps == 0
|
||||
assert metrics.current_epoch == 0
|
||||
assert metrics.current_step == 0
|
||||
assert metrics.peak_cpu_memory == 0
|
||||
|
||||
def test_elapsed_time(self, mock_time):
|
||||
"""Test elapsed_time property."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
|
||||
# Mock time.time() to return 1050.0
|
||||
mock_time.side_effect = [1050.0]
|
||||
|
||||
assert metrics.elapsed_time == 50.0
|
||||
|
||||
def test_epoch_time(self):
|
||||
"""Test epoch_time method."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
|
||||
# No epoch data
|
||||
assert metrics.epoch_time(0) is None
|
||||
|
||||
# Add epoch start but no end
|
||||
metrics.epoch_start_times[0] = 1000.0
|
||||
assert metrics.epoch_time(0) is None
|
||||
|
||||
# Add epoch end
|
||||
metrics.epoch_end_times[0] = 1060.0
|
||||
assert metrics.epoch_time(0) == 60.0
|
||||
|
||||
def test_average_epoch_time(self):
|
||||
"""Test average_epoch_time method."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
|
||||
# No completed epochs
|
||||
assert metrics.average_epoch_time() is None
|
||||
|
||||
# Add one completed epoch
|
||||
metrics.epoch_start_times[0] = 1000.0
|
||||
metrics.epoch_end_times[0] = 1060.0
|
||||
assert metrics.average_epoch_time() == 60.0
|
||||
|
||||
# Add second completed epoch
|
||||
metrics.epoch_start_times[1] = 1060.0
|
||||
metrics.epoch_end_times[1] = 1140.0 # 80 seconds
|
||||
assert metrics.average_epoch_time() == 70.0 # Average of 60 and 80
|
||||
|
||||
# Add incomplete epoch (should not affect average)
|
||||
metrics.epoch_start_times[2] = 1140.0
|
||||
assert metrics.average_epoch_time() == 70.0
|
||||
|
||||
def test_steps_per_second(self, mock_time):
|
||||
"""Test steps_per_second method."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
|
||||
# No steps - first call to time.time()
|
||||
mock_time.side_effect = None
|
||||
mock_time.return_value = 1050.0
|
||||
assert metrics.steps_per_second() is None
|
||||
|
||||
# Add steps - second call to time.time()
|
||||
metrics.total_steps = 100
|
||||
mock_time.return_value = 1050.0 # Keep same time for consistent result
|
||||
assert metrics.steps_per_second() == 2.0 # 100 steps / 50 seconds
|
||||
|
||||
def test_to_dict_basic(self, mock_time):
|
||||
"""Test to_dict method with basic metrics."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
metrics.total_steps = 100
|
||||
metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024 # 2GB
|
||||
|
||||
# Mock elapsed_time
|
||||
mock_time.side_effect = None
|
||||
mock_time.return_value = 1050.0
|
||||
|
||||
result = metrics.to_dict()
|
||||
|
||||
assert result["total_time_seconds"] == 50.0
|
||||
assert result["total_steps"] == 100
|
||||
assert result["steps_per_second"] == 2.0
|
||||
assert result["epochs_completed"] == 0
|
||||
assert result["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
||||
assert "epoch_times" not in result
|
||||
assert "gpu_memory" not in result
|
||||
|
||||
def test_to_dict_with_epochs(self, mock_time):
|
||||
"""Test to_dict method with epoch data."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
metrics.total_steps = 100
|
||||
|
||||
# Add epoch data
|
||||
metrics.epoch_start_times[0] = 1000.0
|
||||
metrics.epoch_end_times[0] = 1060.0
|
||||
metrics.epoch_start_times[1] = 1060.0
|
||||
metrics.epoch_end_times[1] = 1140.0
|
||||
|
||||
# Mock elapsed_time
|
||||
mock_time.side_effect = None
|
||||
mock_time.return_value = 1150.0
|
||||
|
||||
result = metrics.to_dict()
|
||||
|
||||
assert "epoch_times" in result
|
||||
assert result["epoch_times"]["epoch_0_seconds"] == 60.0
|
||||
assert result["epoch_times"]["epoch_1_seconds"] == 80.0
|
||||
assert result["average_epoch_time_seconds"] == 70.0
|
||||
|
||||
def test_to_dict_with_gpu_memory(self, mock_time):
|
||||
"""Test to_dict method with GPU memory data."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
metrics.peak_gpu_memory = {
|
||||
0: 1 * 1024 * 1024 * 1024, # 1GB
|
||||
1: 2 * 1024 * 1024 * 1024, # 2GB
|
||||
}
|
||||
|
||||
# Mock elapsed_time
|
||||
mock_time.side_effect = [1050.0]
|
||||
|
||||
result = metrics.to_dict()
|
||||
|
||||
assert "gpu_memory" in result
|
||||
assert result["gpu_memory"]["gpu_0_peak_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
||||
assert result["gpu_memory"]["gpu_1_peak_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
||||
|
||||
|
||||
class TestRuntimeMetricsTracker:
|
||||
"""Tests for RuntimeMetricsTracker class."""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_initialization(self, mock_time, mock_telemetry_manager):
|
||||
"""Test RuntimeMetricsTracker initialization."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
assert isinstance(tracker.metrics, RuntimeMetrics)
|
||||
assert tracker.metrics.start_time == 1000.0 # First value from mock_time
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_start_epoch(
|
||||
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
|
||||
):
|
||||
"""Test start_epoch method."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
# Reset mock_time to control next value
|
||||
mock_time.side_effect = [1010.0]
|
||||
|
||||
tracker.start_epoch(0)
|
||||
|
||||
assert tracker.metrics.current_epoch == 0
|
||||
assert tracker.metrics.epoch_start_times[0] == 1010.0
|
||||
|
||||
# Verify memory metrics were updated
|
||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
||||
assert 0 in tracker.metrics.peak_gpu_memory
|
||||
assert 1 in tracker.metrics.peak_gpu_memory
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_end_epoch(self, mock_time, mock_telemetry_manager):
|
||||
"""Test end_epoch method."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
# Start epoch 0
|
||||
mock_time.side_effect = [1010.0]
|
||||
tracker.start_epoch(0)
|
||||
|
||||
# End epoch 0
|
||||
mock_time.side_effect = [1060.0]
|
||||
tracker.end_epoch(0)
|
||||
|
||||
assert 0 in tracker.metrics.epoch_end_times
|
||||
assert tracker.metrics.epoch_end_times[0] == 1060.0
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_update_step(
|
||||
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
|
||||
):
|
||||
"""Test update_step method."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
# Update step to a non-multiple of 100
|
||||
tracker.update_step(42)
|
||||
|
||||
assert tracker.metrics.current_step == 42
|
||||
assert tracker.metrics.total_steps == 1
|
||||
|
||||
# Memory metrics should not be updated for non-multiple of 100
|
||||
assert tracker.metrics.peak_cpu_memory == 0
|
||||
|
||||
# Update step to a multiple of 100
|
||||
tracker.update_step(100)
|
||||
|
||||
assert tracker.metrics.current_step == 100
|
||||
assert tracker.metrics.total_steps == 2
|
||||
|
||||
# Memory metrics should be updated for multiple of 100
|
||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_update_memory_metrics(
|
||||
self, mock_psutil, mock_torch, mock_telemetry_manager
|
||||
):
|
||||
"""Test update_memory_metrics method."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
# Initial memory state
|
||||
assert tracker.metrics.peak_cpu_memory == 0
|
||||
assert tracker.metrics.peak_gpu_memory == {}
|
||||
|
||||
# Update memory metrics
|
||||
tracker.update_memory_metrics()
|
||||
|
||||
# Verify CPU memory
|
||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
||||
|
||||
# Verify GPU memory
|
||||
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
|
||||
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
|
||||
|
||||
# Change mocked memory values to be lower
|
||||
mock_process = mock_psutil.Process.return_value
|
||||
mock_memory_info = mock_process.memory_info.return_value
|
||||
mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024 # 0.5GB
|
||||
|
||||
mock_torch.cuda.memory_allocated.side_effect = (
|
||||
lambda device: (device + 0.5) * 1024 * 1024 * 1024
|
||||
)
|
||||
|
||||
# Update memory metrics again
|
||||
tracker.update_memory_metrics()
|
||||
|
||||
# Peak values should not decrease
|
||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
||||
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
|
||||
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
|
||||
|
||||
# Change mocked memory values to be higher
|
||||
mock_memory_info.rss = 2 * 1024 * 1024 * 1024 # 2GB
|
||||
|
||||
mock_torch.cuda.memory_allocated.side_effect = (
|
||||
lambda device: (device + 2) * 1024 * 1024 * 1024
|
||||
)
|
||||
|
||||
# Update memory metrics again
|
||||
tracker.update_memory_metrics()
|
||||
|
||||
# Peak values should increase
|
||||
assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024
|
||||
assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024
|
||||
assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager):
|
||||
"""Test get_memory_metrics method."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
# Set peak memory values
|
||||
tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024
|
||||
tracker.metrics.peak_gpu_memory = {
|
||||
0: 3 * 1024 * 1024 * 1024,
|
||||
1: 4 * 1024 * 1024 * 1024,
|
||||
}
|
||||
|
||||
# Get memory metrics
|
||||
memory_metrics = tracker.get_memory_metrics()
|
||||
|
||||
# Verify CPU memory
|
||||
assert (
|
||||
memory_metrics["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
||||
) # Current value from mock
|
||||
assert (
|
||||
memory_metrics["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
||||
) # Peak value we set
|
||||
|
||||
# Verify GPU memory
|
||||
assert (
|
||||
memory_metrics["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
||||
) # Current value from mock
|
||||
assert (
|
||||
memory_metrics["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024
|
||||
) # Peak value we set
|
||||
assert (
|
||||
memory_metrics["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
||||
) # Current value from mock
|
||||
assert (
|
||||
memory_metrics["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024
|
||||
) # Peak value we set
|
||||
Reference in New Issue
Block a user