Merge branch 'main' into telemetry-opt-in
This commit is contained in:
28
.github/workflows/base.yml
vendored
28
.github/workflows/base.yml
vendored
@@ -53,6 +53,20 @@ jobs:
|
|||||||
pytorch: 2.8.0
|
pytorch: 2.8.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-base"
|
||||||
|
- cuda: "130"
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.0
|
||||||
|
torch_cuda_arch_list: "9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-base"
|
||||||
# - cuda: "128"
|
# - cuda: "128"
|
||||||
# cuda_version: 12.8.1
|
# cuda_version: 12.8.1
|
||||||
# cudnn_version: ""
|
# cudnn_version: ""
|
||||||
@@ -129,6 +143,20 @@ jobs:
|
|||||||
pytorch: 2.8.0
|
pytorch: 2.8.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-uv-base"
|
||||||
|
- cuda: "130"
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.0
|
||||||
|
torch_cuda_arch_list: "9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-uv-base"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ repos:
|
|||||||
- id: no-commit-to-branch
|
- id: no-commit-to-branch
|
||||||
args: ['--branch', 'main']
|
args: ['--branch', 'main']
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.13.3
|
rev: v0.14.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix]
|
args: [--fix]
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
||||||
|
|
||||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
||||||
ENV CUDA="{{ CUDA }}"
|
ENV CUDA="{{ CUDA }}"
|
||||||
|
|||||||
@@ -37,16 +37,22 @@ WORKDIR /workspace
|
|||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||||
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
|
|
||||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
|
|
||||||
python3 -m pip cache purge
|
python3 -m pip cache purge
|
||||||
|
|
||||||
|
RUN if [ "$CUDA" != "130" ] ; then \
|
||||||
|
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.4"; \
|
||||||
|
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
|
||||||
|
python3 -m pip cache purge; \
|
||||||
|
fi
|
||||||
|
|
||||||
RUN git lfs install --skip-repo && \
|
RUN git lfs install --skip-repo && \
|
||||||
pip3 install awscli && \
|
pip3 install awscli && \
|
||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||||
pip3 cache purge
|
pip3 cache purge
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
|
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
||||||
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
|
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -34,3 +34,9 @@ RUN uv pip install packaging setuptools wheel psutil \
|
|||||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||||
&& uv pip install awscli pydantic
|
&& uv pip install awscli pydantic
|
||||||
|
|
||||||
|
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
||||||
|
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
fi
|
||||||
|
|||||||
@@ -63,6 +63,14 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
|
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
|
||||||
|
|
||||||
|
**Q: Can we mix text and text+image datasets for VLM training?**
|
||||||
|
|
||||||
|
> A: Yes, you can for newer VLM architectures. The ones that would not work are the LLaVA and Pixtral architectures. If you notice one not working, please let us know!
|
||||||
|
|
||||||
|
**Q: Why is `memory/max_*` different from `nvidia-smi`?**
|
||||||
|
|
||||||
|
> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.
|
||||||
|
|
||||||
### Chat templates
|
### Chat templates
|
||||||
|
|
||||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||||
|
|||||||
@@ -27,3 +27,9 @@ learning_rate: 2e-5
|
|||||||
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
|
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
|
||||||
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
|
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
|
||||||
self attention `q_proj` module.
|
self attention `q_proj` module.
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
|
||||||
|
We currently only support varying `lr`. If you're interested in adding support for other parameters (e.g. `weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17
|
||||||
|
|
||||||
|
:::
|
||||||
|
|||||||
@@ -56,10 +56,14 @@ image_resize_algorithm: bilinear
|
|||||||
|
|
||||||
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
|
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
|
||||||
|
|
||||||
::: {.callout-warning}
|
::: {.callout-tip}
|
||||||
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
|
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
As of now, we neither truncate nor drop samples based on `sequence_len`, as each architecture has different ways to process non-text tokens. We are looking for help on this.
|
||||||
|
:::
|
||||||
|
|
||||||
### Mllama {#sec-mllama}
|
### Mllama {#sec-mllama}
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
@@ -168,6 +172,14 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
|||||||
chat_template: qwen2_vl # same as qwen2-vl
|
chat_template: qwen2_vl # same as qwen2-vl
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Qwen3-VL {#sec-qwen3-vl}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: Qwen/Qwen3-VL-4B-Instruct
|
||||||
|
|
||||||
|
chat_template: qwen2_vl # same as qwen2-vl
|
||||||
|
```
|
||||||
|
|
||||||
### SmolVLM2 {#sec-smolvlm2}
|
### SmolVLM2 {#sec-smolvlm2}
|
||||||
|
|
||||||
::: {.callout-tip}
|
::: {.callout-tip}
|
||||||
|
|||||||
@@ -219,6 +219,21 @@ DPO supports the following types with the following dataset format:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### chat_template.argilla_chat
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"chosen": [
|
||||||
|
{"role": "user", "content": "..."},
|
||||||
|
{"role": "assistant", "content": "..."}
|
||||||
|
],
|
||||||
|
"rejected": [
|
||||||
|
{"role": "user", "content": "..."},
|
||||||
|
{"role": "assistant", "content": "..."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
#### chat_template.default
|
#### chat_template.default
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
50
examples/llama-3/opentelemetry-qlora.yml
Normal file
50
examples/llama-3/opentelemetry-qlora.yml
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
base_model: NousResearch/Llama-3.2-1B
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
|
||||||
|
output_dir: ./outputs/opentelemetry-example
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
sequence_len: 512
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
# OpenTelemetry Configuration
|
||||||
|
use_otel_metrics: true
|
||||||
|
otel_metrics_host: "localhost"
|
||||||
|
otel_metrics_port: 8000
|
||||||
|
|
||||||
|
# Disable WandB
|
||||||
|
use_wandb: false
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_32bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: false
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 2
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: "<|end_of_text|>"
|
||||||
@@ -12,7 +12,7 @@ Before starting, ensure you have:
|
|||||||
Run the thinking model fine-tuning:
|
Run the thinking model fine-tuning:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
axolotl train magistral-small-think-qlora.yaml
|
axolotl train examples/magistral/think/magistral-small-think-qlora.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
This config uses about 19.1 GiB VRAM.
|
This config uses about 19.1 GiB VRAM.
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ Before starting, ensure you have:
|
|||||||
|
|
||||||
3. Run the fine-tuning:
|
3. Run the fine-tuning:
|
||||||
```bash
|
```bash
|
||||||
axolotl train magistral-small-vision-24B-qlora.yml
|
axolotl train examples/magistral/vision/magistral-small-vision-24B-qlora.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
This config uses about 17GiB VRAM.
|
This config uses about 17GiB VRAM.
|
||||||
|
|||||||
51
examples/mistral/mistral-small/README.md
Normal file
51
examples/mistral/mistral-small/README.md
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# Mistral Small 3.1/3.2 Fine-tuning
|
||||||
|
|
||||||
|
This guide covers fine-tuning [Mistral Small 3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503) and [Mistral Small 3.2](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506) with vision capabilities using Axolotl.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
Before starting, ensure you have:
|
||||||
|
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
1. Install the required vision lib:
|
||||||
|
```bash
|
||||||
|
pip install 'mistral-common[opencv]==1.8.5'
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Download the example dataset image:
|
||||||
|
```bash
|
||||||
|
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Run the fine-tuning:
|
||||||
|
```bash
|
||||||
|
axolotl train examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 29.4 GiB VRAM.
|
||||||
|
|
||||||
|
## Dataset Format
|
||||||
|
|
||||||
|
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||||
|
|
||||||
|
One exception: passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||||
|
{"role": "user", "content": [
|
||||||
|
{ "type": "text", "text": "What's in this image?"},
|
||||||
|
{"type": "image", "path": "path/to/image.jpg" }
|
||||||
|
]},
|
||||||
|
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- Sample Packing is not supported for multi-modality training currently.
|
||||||
@@ -39,7 +39,7 @@ wandb_name:
|
|||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 2
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
|
|||||||
@@ -5,27 +5,27 @@ bitsandbytes==0.47.0
|
|||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
liger-kernel==0.6.1
|
liger-kernel==0.6.3
|
||||||
# END section
|
# END section
|
||||||
|
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
huggingface_hub>=0.33.0
|
huggingface_hub>=0.36.0
|
||||||
peft>=0.17.1
|
peft>=0.17.1
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
transformers==4.57.0
|
transformers==4.57.1
|
||||||
accelerate==1.10.1
|
accelerate==1.10.1
|
||||||
datasets==4.0.0
|
datasets==4.0.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.23.0
|
trl==0.24.0
|
||||||
hf_xet==1.1.5
|
hf_xet==1.2.0
|
||||||
kernels==0.9.0
|
kernels>=0.9.0
|
||||||
trackio
|
trackio
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==5.41.1
|
gradio==5.49.1
|
||||||
|
|
||||||
modal==1.0.2
|
modal==1.0.2
|
||||||
pydantic==2.10.6
|
pydantic==2.10.6
|
||||||
|
|||||||
18
setup.py
18
setup.py
@@ -49,7 +49,7 @@ def parse_requirements(extras_require_map):
|
|||||||
try:
|
try:
|
||||||
torch_version = version("torch")
|
torch_version = version("torch")
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
torch_version = "2.6.0" # default to torch 2.6
|
torch_version = "2.8.0" # default to torch 2.8.0
|
||||||
_install_requires.append(f"torch=={torch_version}")
|
_install_requires.append(f"torch=={torch_version}")
|
||||||
|
|
||||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||||
@@ -62,8 +62,12 @@ def parse_requirements(extras_require_map):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
if (major, minor) >= (2, 8):
|
if (major, minor) >= (2, 9):
|
||||||
pass
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
|
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
|
||||||
|
elif (major, minor) >= (2, 8):
|
||||||
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
|
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
|
||||||
elif (major, minor) >= (2, 7):
|
elif (major, minor) >= (2, 7):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
@@ -158,7 +162,13 @@ extras_require = {
|
|||||||
"llmcompressor": [
|
"llmcompressor": [
|
||||||
"llmcompressor==0.5.1",
|
"llmcompressor==0.5.1",
|
||||||
],
|
],
|
||||||
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
|
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
|
||||||
|
"opentelemetry": [
|
||||||
|
"opentelemetry-api",
|
||||||
|
"opentelemetry-sdk",
|
||||||
|
"opentelemetry-exporter-prometheus",
|
||||||
|
"prometheus-client",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||||
extras_require
|
extras_require
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ MOE_ARCH_BLOCK = {
|
|||||||
"mixtral": "MixtralSparseMoeBlock",
|
"mixtral": "MixtralSparseMoeBlock",
|
||||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||||
|
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||||
"deepseek_v2": "DeepseekV2MoE",
|
"deepseek_v2": "DeepseekV2MoE",
|
||||||
|
"deepseek_v3": "DeepseekV3MoE",
|
||||||
"gpt_oss": "GptOssDecoderLayer",
|
"gpt_oss": "GptOssDecoderLayer",
|
||||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,7 +31,11 @@ from axolotl.integrations.base import PluginManager
|
|||||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||||
from axolotl.telemetry.callbacks import TelemetryCallback
|
from axolotl.telemetry.callbacks import TelemetryCallback
|
||||||
from axolotl.telemetry.manager import TelemetryManager
|
from axolotl.telemetry.manager import TelemetryManager
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import (
|
||||||
|
is_comet_available,
|
||||||
|
is_mlflow_available,
|
||||||
|
is_opentelemetry_available,
|
||||||
|
)
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
GCCallback,
|
GCCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
@@ -136,6 +140,12 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
callbacks.append(OpenTelemetryMetricsCallback(self.cfg))
|
||||||
if self.cfg.save_first_step:
|
if self.cfg.save_first_step:
|
||||||
callbacks.append(SaveModelOnFirstStepCallback())
|
callbacks.append(SaveModelOnFirstStepCallback())
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from transformers import (
|
|||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
Trainer,
|
Trainer,
|
||||||
)
|
)
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.reward_trainer import DataCollatorForPreference
|
||||||
|
|
||||||
from axolotl.core.builders.base import TrainerBuilderBase
|
from axolotl.core.builders.base import TrainerBuilderBase
|
||||||
from axolotl.core.trainers import (
|
from axolotl.core.trainers import (
|
||||||
@@ -453,7 +453,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
DataCollatorWithFlattening,
|
DataCollatorWithFlattening,
|
||||||
RewardDataCollatorWithPadding,
|
DataCollatorForPreference,
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
collator_args = [self.tokenizer]
|
collator_args = [self.tokenizer]
|
||||||
@@ -470,7 +470,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if kwargs and isinstance(kwargs, dict):
|
if kwargs and isinstance(kwargs, dict):
|
||||||
kwargs.update(collator_cls_and_kwargs[1])
|
kwargs.update(collator_cls_and_kwargs[1])
|
||||||
elif self.cfg.reward_model:
|
elif self.cfg.reward_model:
|
||||||
collator = RewardDataCollatorWithPadding
|
collator = DataCollatorForPreference
|
||||||
|
tokenizer = collator_args.pop(0)
|
||||||
|
kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||||
|
kwargs.pop("padding")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
||||||
# supported multipack models, or non-flash-attention llama
|
# supported multipack models, or non-flash-attention llama
|
||||||
|
|||||||
@@ -225,17 +225,6 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
data_collator = self.data_collator if is_training else self.eval_data_collator
|
data_collator = self.data_collator if is_training else self.eval_data_collator
|
||||||
|
|
||||||
if dataset.column_names and "length" in dataset.column_names:
|
|
||||||
dataset = dataset.remove_columns(["length"])
|
|
||||||
if (
|
|
||||||
dataset.column_names
|
|
||||||
and "position_ids" in dataset.column_names
|
|
||||||
and "attention_mask" in dataset.column_names
|
|
||||||
and self.args.sample_packing
|
|
||||||
and self.args.sample_packing_drop_attention_mask
|
|
||||||
):
|
|
||||||
dataset = dataset.remove_columns(["attention_mask"])
|
|
||||||
|
|
||||||
if isinstance(dataset, datasets.Dataset):
|
if isinstance(dataset, datasets.Dataset):
|
||||||
if is_training:
|
if is_training:
|
||||||
if not self.args.sample_packing or self.args.pretraining:
|
if not self.args.sample_packing or self.args.pretraining:
|
||||||
@@ -294,6 +283,18 @@ class AxolotlTrainer(
|
|||||||
):
|
):
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
|
|
||||||
|
if dataset.column_names and "length" in dataset.column_names:
|
||||||
|
dataset = dataset.remove_columns(["length"])
|
||||||
|
|
||||||
|
if (
|
||||||
|
dataset.column_names
|
||||||
|
and "position_ids" in dataset.column_names
|
||||||
|
and "attention_mask" in dataset.column_names
|
||||||
|
and self.args.sample_packing
|
||||||
|
and self.args.sample_packing_drop_attention_mask
|
||||||
|
):
|
||||||
|
dataset = dataset.remove_columns(["attention_mask"])
|
||||||
|
|
||||||
dataloader = DataLoader(dataset, **dataloader_params)
|
dataloader = DataLoader(dataset, **dataloader_params)
|
||||||
|
|
||||||
# Accelerator.free_memory() will destroy the references, so
|
# Accelerator.free_memory() will destroy the references, so
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ class GRPOStrategy:
|
|||||||
if trl.vllm_mode:
|
if trl.vllm_mode:
|
||||||
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
||||||
if trl.vllm_mode == "colocate":
|
if trl.vllm_mode == "colocate":
|
||||||
grpo_args_kwargs["enable_sleep_mode"] = trl.vllm_enable_sleep_mode # type: ignore[attr-defined]
|
grpo_args_kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode # type: ignore[attr-defined]
|
||||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||||
vllm_cfg.gpu_memory_utilization
|
vllm_cfg.gpu_memory_utilization
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import torch
|
|||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .utils import create_bidirectional_attention_mask
|
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -360,7 +360,7 @@ def _diffusion_step(
|
|||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
outputs = model(input_ids=sequence, attention_mask=attention_mask)
|
outputs = model(input_ids=sequence, attention_mask=attention_mask)
|
||||||
logits = outputs.logits
|
logits = shift_logits_to_input_positions(outputs.logits)
|
||||||
|
|
||||||
# Only sample at currently masked positions
|
# Only sample at currently masked positions
|
||||||
if current_mask.any():
|
if current_mask.any():
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .callbacks import DiffusionGenerationCallback
|
from .callbacks import DiffusionGenerationCallback
|
||||||
from .utils import create_bidirectional_attention_mask
|
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -207,7 +207,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
|||||||
input_ids=noisy_batch.long(),
|
input_ids=noisy_batch.long(),
|
||||||
attention_mask=bidirectional_mask,
|
attention_mask=bidirectional_mask,
|
||||||
)
|
)
|
||||||
logits = outputs.logits
|
logits = shift_logits_to_input_positions(outputs.logits)
|
||||||
|
|
||||||
if masked_indices.sum() > 0:
|
if masked_indices.sum() > 0:
|
||||||
valid_indices = torch.where(masked_indices)
|
valid_indices = torch.where(masked_indices)
|
||||||
|
|||||||
@@ -157,3 +157,10 @@ def create_bidirectional_attention_mask(
|
|||||||
|
|
||||||
# Add head dimension: [batch_size, 1, seq_len, seq_len]
|
# Add head dimension: [batch_size, 1, seq_len, seq_len]
|
||||||
return bidirectional_mask.unsqueeze(1)
|
return bidirectional_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Align next-token logits with their input token positions for diffusion."""
|
||||||
|
if logits.size(1) <= 1:
|
||||||
|
return logits
|
||||||
|
return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
||||||
|
|||||||
@@ -517,9 +517,6 @@ class ModelLoader:
|
|||||||
if self.cfg.model_quantization_config_kwargs:
|
if self.cfg.model_quantization_config_kwargs:
|
||||||
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
||||||
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
||||||
else:
|
|
||||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
|
||||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
|
||||||
|
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
if not hasattr(self.model_config, "quantization_config"):
|
if not hasattr(self.model_config, "quantization_config"):
|
||||||
@@ -554,9 +551,7 @@ class ModelLoader:
|
|||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**self.model_config.quantization_config
|
**self.model_config.quantization_config
|
||||||
)
|
)
|
||||||
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
|
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
|
||||||
"load_in_4bit", False
|
|
||||||
):
|
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"llm_int8_threshold": 6.0,
|
"llm_int8_threshold": 6.0,
|
||||||
@@ -582,9 +577,7 @@ class ModelLoader:
|
|||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
|
elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit:
|
||||||
"load_in_8bit", False
|
|
||||||
):
|
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
}
|
}
|
||||||
@@ -598,11 +591,6 @@ class ModelLoader:
|
|||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
|
||||||
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
|
||||||
self.model_kwargs.pop("load_in_8bit", None)
|
|
||||||
self.model_kwargs.pop("load_in_4bit", None)
|
|
||||||
|
|
||||||
def _set_attention_config(self):
|
def _set_attention_config(self):
|
||||||
"""Sample packing uses custom FA2 patch"""
|
"""Sample packing uses custom FA2 patch"""
|
||||||
if self.cfg.attn_implementation:
|
if self.cfg.attn_implementation:
|
||||||
|
|||||||
@@ -134,6 +134,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|||||||
|
|
||||||
return Qwen2Attention
|
return Qwen2Attention
|
||||||
|
|
||||||
|
if model_type == "qwen3_vl":
|
||||||
|
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextAttention
|
||||||
|
|
||||||
|
return Qwen3VLTextAttention
|
||||||
|
|
||||||
if model_type == "mllama":
|
if model_type == "mllama":
|
||||||
from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention
|
from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention
|
||||||
|
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ from axolotl.utils.logging import get_logger
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
|
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
|
||||||
PATCHED_GUARD = (
|
PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
|
||||||
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_prepare_context_parallel_inputs() -> None:
|
def patch_prepare_context_parallel_inputs() -> None:
|
||||||
|
|||||||
@@ -71,10 +71,10 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
]
|
]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids_chosen": chosen_tokenized["input_ids"],
|
"chosen_input_ids": chosen_tokenized["input_ids"],
|
||||||
"attention_mask_chosen": chosen_tokenized["attention_mask"],
|
"attention_mask_chosen": chosen_tokenized["attention_mask"],
|
||||||
"labels_chosen": 1.0,
|
"labels_chosen": 1.0,
|
||||||
"input_ids_rejected": rejected_tokenized["input_ids"],
|
"rejected_input_ids": rejected_tokenized["input_ids"],
|
||||||
"attention_mask_rejected": rejected_tokenized["attention_mask"],
|
"attention_mask_rejected": rejected_tokenized["attention_mask"],
|
||||||
"labels_rejected": 0.0,
|
"labels_rejected": 0.0,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -120,3 +120,123 @@ def default(cfg, dataset_idx=0, **kwargs):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
return transform_fn, {"remove_columns": [field_messages]}
|
return transform_fn, {"remove_columns": [field_messages]}
|
||||||
|
|
||||||
|
|
||||||
|
def argilla_chat(cfg, dataset_idx=0, **kwargs):
    """
    DPO chat template strategy for argilla-style datasets.

    For argilla-style datasets where chosen/rejected contain full conversations
    instead of single response messages. Extracts the conversation history from
    the chosen field and formats both chosen/rejected responses using the
    configured chat template.

    Args:
        cfg: Configuration object containing chat_template and dataset settings
        dataset_idx: Index of the dataset in the config (default: 0)
        **kwargs: Additional keyword arguments (unused)

    Returns:
        tuple: (transform_fn, dataset_kwargs) where:
            - transform_fn: Function to transform dataset samples
            - dataset_kwargs: Dict with 'remove_columns' specifying columns to drop

    Raises:
        The returned transform_fn raises:
            - KeyError: if a message role is not listed in the role mapping
            - IndexError: if a chosen/rejected conversation is empty
            - ValueError: if a response's content cannot be located in its
              rendered chat-template output

    Dataset format:
        {
            "chosen": [
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ],
            "rejected": [
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ]
        }
    """
    ds_cfg = cfg["datasets"][dataset_idx]
    ds_cfg = handle_legacy_message_fields_logic(ds_cfg)

    chat_template_choice, chat_template_jinja = extract_chat_template_args(
        cfg=cfg, ds_cfg=ds_cfg
    )
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")
    message_property_mappings = ds_cfg.get(
        "message_property_mappings",
        {
            "role": "role",
            "content": "content",
        },
    )
    role_map_inv = ds_cfg.get(
        "roles",
        {
            "user": ["user"],
            "assistant": ["assistant"],
            "system": ["system"],
        },
    )
    # Invert {target_role: [source_roles]} into {source_role: target_role}.
    role_map = {}
    for target, sources in role_map_inv.items():
        for source in sources:
            role_map[source] = target

    # Placeholder turn so a lone response can be rendered by the template;
    # everything before the response content is cut away afterwards.
    dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

    def _render_response(tokenizer, chat_template_string, response):
        """Render a single response and keep only its content plus trailing template tokens."""
        rendered = tokenizer.apply_chat_template(
            [dummy_user_message, response],
            add_generation_prompt=False,
            chat_template=chat_template_string,
            tokenize=False,
        )
        content = response["content"]
        start = rendered.find(content)
        if start == -1:
            # The template may have trimmed surrounding whitespace from the
            # content; retry with the trimmed text before giving up.
            trimmed = content.strip()
            if trimmed:
                start = rendered.find(trimmed)
        if start == -1:
            # str.find() returns -1 on a miss; slicing from -1 would silently
            # reduce the response to its final character.
            raise ValueError(
                "Could not locate the response content in the rendered chat "
                f"template output: {rendered!r}"
            )
        return rendered[start:].rstrip()

    def transform_fn(sample, tokenizer=None):
        chat_template_string = get_chat_template(
            user_choice=chat_template_choice,
            jinja_template=chat_template_jinja,
            tokenizer=tokenizer,
        )

        chosen_raw = sample[field_chosen]
        rejected_raw = sample[field_rejected]

        # Extract messages (all but last) and responses (last message)
        chosen_messages = [
            {
                "role": role_map[m[message_property_mappings["role"]]],
                "content": m[message_property_mappings["content"]],
            }
            for m in chosen_raw[:-1]
        ]
        chosen_response = {
            "role": role_map[chosen_raw[-1][message_property_mappings["role"]]],
            "content": chosen_raw[-1][message_property_mappings["content"]],
        }

        rejected_response = {
            "role": role_map[rejected_raw[-1][message_property_mappings["role"]]],
            "content": rejected_raw[-1][message_property_mappings["content"]],
        }

        result = {}
        # The prompt is built from the chosen conversation's history only.
        result["prompt"] = tokenizer.apply_chat_template(
            chosen_messages,
            add_generation_prompt=True,
            chat_template=chat_template_string,
            tokenize=False,
        )
        result["chosen"] = _render_response(
            tokenizer, chat_template_string, chosen_response
        )
        result["rejected"] = _render_response(
            tokenizer, chat_template_string, rejected_response
        )

        return result

    return transform_fn, {"remove_columns": [field_chosen, field_rejected]}
|
||||||
|
|||||||
@@ -17,6 +17,13 @@ def is_comet_available():
|
|||||||
return importlib.util.find_spec("comet_ml") is not None
|
return importlib.util.find_spec("comet_ml") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_opentelemetry_available():
    """Return True when both `opentelemetry` and `prometheus_client` can be located."""
    required_modules = ("opentelemetry", "prometheus_client")
    # all() stops at the first missing module, so the lookup order and
    # short-circuiting match a chained `and`.
    return all(
        importlib.util.find_spec(module_name) is not None
        for module_name in required_modules
    )
|
||||||
|
|
||||||
|
|
||||||
def get_pytorch_version() -> tuple[int, int, int]:
|
def get_pytorch_version() -> tuple[int, int, int]:
|
||||||
"""
|
"""
|
||||||
Get Pytorch version as a tuple of (major, minor, patch).
|
Get Pytorch version as a tuple of (major, minor, patch).
|
||||||
|
|||||||
238
src/axolotl/utils/callbacks/opentelemetry.py
Normal file
238
src/axolotl/utils/callbacks/opentelemetry.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""OpenTelemetry metrics callback for Axolotl training"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
TrainerCallback,
|
||||||
|
TrainerControl,
|
||||||
|
TrainerState,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
# Optional dependency guard: the callback module must stay importable even when
# the OpenTelemetry / Prometheus packages are not installed.
try:
    from opentelemetry import metrics
    from opentelemetry.exporter.prometheus import PrometheusMetricReader
    from opentelemetry.metrics import set_meter_provider
    from opentelemetry.sdk.metrics import MeterProvider as SDKMeterProvider
    from prometheus_client import start_http_server

    OPENTELEMETRY_AVAILABLE = True
except ImportError:
    # Name the distributions that provide the modules imported above so the
    # hint is a command that can actually be run.
    LOG.warning(
        "OpenTelemetry not available. Install it with: "
        "pip install opentelemetry-sdk opentelemetry-exporter-prometheus prometheus-client"
    )
    OPENTELEMETRY_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
class OpenTelemetryMetricsCallback(TrainerCallback):
    """
    TrainerCallback that exports training metrics to OpenTelemetry/Prometheus.

    Values are read from the dictionaries the Trainer passes to the callback
    hooks and mirrored into OpenTelemetry instruments:
    - Training loss (`loss`)
    - Evaluation loss (`eval_loss`)
    - Learning rate (`learning_rate`)
    - Epoch progress (`epoch`)
    - Global step count (incremented once per `on_step_end`)
    - Gradient norm (`grad_norm`)
    - Memory usage (`memory_usage`, when present in the logs)
    - Any other numeric `eval_*` metric, exported as `axolotl_<key>`

    Metrics are exposed via HTTP endpoint for Prometheus scraping. When the
    optional dependencies are missing or initialization fails, the callback
    sets `metrics_enabled` to False and every hook becomes a no-op.
    """

    def __init__(self, cfg):
        """
        Set up the meter provider and instruments.

        Args:
            cfg: Configuration object; `otel_metrics_host` and
                `otel_metrics_port` are read from it via attribute access.
        """
        self.cfg = cfg
        # The config schema types both settings as optional, so an explicit
        # None has to fall back to the defaults as well as a missing attribute.
        host = getattr(cfg, "otel_metrics_host", None)
        port = getattr(cfg, "otel_metrics_port", None)
        self.metrics_host = "localhost" if host is None else host
        self.metrics_port = 8000 if port is None else port
        self.server_started = False
        self.metrics_lock = threading.Lock()
        # Gauges created on demand for extra eval metrics, keyed by metric key.
        self._eval_gauges = {}

        if not OPENTELEMETRY_AVAILABLE:
            LOG.warning("OpenTelemetry not available, metrics will not be collected")
            self.metrics_enabled = False
            return

        self.metrics_enabled = True

        try:
            # Create Prometheus metrics reader
            prometheus_reader = PrometheusMetricReader()

            # Create meter provider with Prometheus exporter
            provider = SDKMeterProvider(metric_readers=[prometheus_reader])
            set_meter_provider(provider)

            # Get meter for creating metrics
            self.meter = metrics.get_meter("axolotl.training")

            # Create metrics
            self._create_metrics()

        except Exception as e:
            # Metrics are best-effort: never let a telemetry failure stop training.
            LOG.warning(f"Failed to initialize OpenTelemetry metrics: {e}")
            self.metrics_enabled = False

    def _create_metrics(self):
        """Create all metrics that will be tracked"""
        self.train_loss_gauge = self.meter.create_gauge(
            name="axolotl_train_loss",
            description="Current training loss",
            unit="1",
        )

        self.eval_loss_gauge = self.meter.create_gauge(
            name="axolotl_eval_loss",
            description="Current evaluation loss",
            unit="1",
        )

        self.learning_rate_gauge = self.meter.create_gauge(
            name="axolotl_learning_rate",
            description="Current learning rate",
            unit="1",
        )

        self.epoch_gauge = self.meter.create_gauge(
            name="axolotl_epoch",
            description="Current training epoch",
            unit="1",
        )

        self.global_step_counter = self.meter.create_counter(
            name="axolotl_global_steps",
            description="Total training steps completed",
            unit="1",
        )

        self.grad_norm_gauge = self.meter.create_gauge(
            name="axolotl_gradient_norm",
            description="Gradient norm",
            unit="1",
        )

        self.memory_usage_gauge = self.meter.create_gauge(
            name="axolotl_memory_usage",
            description="Current memory usage in MB",
            unit="MB",
        )

    def _get_eval_gauge(self, key, gauge_name):
        """Return the gauge for an extra eval metric, creating it on first use."""
        with self.metrics_lock:
            gauge = self._eval_gauges.get(key)
            if gauge is None:
                gauge = self.meter.create_gauge(
                    name=gauge_name,
                    description=f"Evaluation metric: {key}",
                    unit="1",
                )
                self._eval_gauges[key] = gauge
            return gauge

    def _start_metrics_server(self):
        """Start the HTTP server for metrics exposure"""
        if self.server_started:
            return

        try:
            start_http_server(self.metrics_port, addr=self.metrics_host)
            self.server_started = True
            LOG.info(
                f"OpenTelemetry metrics server started on http://{self.metrics_host}:{self.metrics_port}/metrics"
            )

        except Exception as e:
            # A failed bind is logged and training continues without an endpoint.
            LOG.error(f"Failed to start OpenTelemetry metrics server: {e}")

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Called at the beginning of training"""
        if not self.metrics_enabled:
            return

        self._start_metrics_server()
        LOG.info("OpenTelemetry metrics collection started")

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Optional[Dict[str, float]] = None,
        **kwargs,
    ):
        """Called when logging occurs; copies known log keys into their gauges"""
        if not self.metrics_enabled or not logs:
            return

        if "loss" in logs:
            self.train_loss_gauge.set(logs["loss"])

        if "eval_loss" in logs:
            self.eval_loss_gauge.set(logs["eval_loss"])

        if "learning_rate" in logs:
            self.learning_rate_gauge.set(logs["learning_rate"])

        if "epoch" in logs:
            self.epoch_gauge.set(logs["epoch"])

        if "grad_norm" in logs:
            self.grad_norm_gauge.set(logs["grad_norm"])
        if "memory_usage" in logs:
            self.memory_usage_gauge.set(logs["memory_usage"])

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Called at the end of each training step"""
        if not self.metrics_enabled:
            return

        # Update step counter and epoch
        self.global_step_counter.add(1)
        if state.epoch is not None:
            self.epoch_gauge.set(state.epoch)

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: Optional[Dict[str, float]] = None,
        **kwargs,
    ):
        """Called after evaluation; exports eval_loss and any other numeric eval_* metric"""
        if not self.metrics_enabled or not metrics:
            return

        if "eval_loss" in metrics:
            self.eval_loss_gauge.set(metrics["eval_loss"])

        # Record any other eval metrics as gauges
        for key, value in metrics.items():
            # eval_loss already has a dedicated gauge; registering a second
            # instrument under the same name would only duplicate it.
            if key == "eval_loss" or not key.startswith("eval_"):
                continue
            if not isinstance(value, (int, float)):
                continue

            gauge_name = f"axolotl_{key}"
            try:
                # Reuse the gauge from earlier evaluations instead of creating
                # a new instrument on every call.
                self._get_eval_gauge(key, gauge_name).set(value)
            except Exception as e:
                LOG.warning(f"Failed to create/update metric {gauge_name}: {e}")

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Called at the end of training; the metrics server is left running"""
        if not self.metrics_enabled:
            return

        LOG.info("Training completed. OpenTelemetry metrics collection finished.")
        LOG.info(
            f"Metrics are still available at http://{self.metrics_host}:{self.metrics_port}/metrics"
        )
|
||||||
@@ -239,6 +239,11 @@ def _load_from_local_path(
|
|||||||
return load_dataset(dataset_config.path, **load_dataset_kwargs)
|
return load_dataset(dataset_config.path, **load_dataset_kwargs)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
dataset_type = get_dataset_type(dataset_config)
|
dataset_type = get_dataset_type(dataset_config)
|
||||||
|
|
||||||
|
# For single file datasets, HF always creates only a "train" split
|
||||||
|
if dataset_type in ("json", "csv", "text"):
|
||||||
|
load_dataset_kwargs["split"] = "train"
|
||||||
|
|
||||||
return load_dataset(
|
return load_dataset(
|
||||||
dataset_type,
|
dataset_type,
|
||||||
data_files=dataset_config.path,
|
data_files=dataset_config.path,
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from axolotl.utils.schemas.integrations import (
|
|||||||
GradioConfig,
|
GradioConfig,
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
|
OpenTelemetryConfig,
|
||||||
RayConfig,
|
RayConfig,
|
||||||
WandbConfig,
|
WandbConfig,
|
||||||
)
|
)
|
||||||
@@ -60,6 +61,7 @@ class AxolotlInputConfig(
|
|||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
CometConfig,
|
CometConfig,
|
||||||
|
OpenTelemetryConfig,
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
GradioConfig,
|
GradioConfig,
|
||||||
RayConfig,
|
RayConfig,
|
||||||
|
|||||||
@@ -176,3 +176,27 @@ class RayConfig(BaseModel):
|
|||||||
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
|
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenTelemetryConfig(BaseModel):
    """OpenTelemetry configuration subset"""

    # Master switch; metrics export is off unless explicitly enabled.
    use_otel_metrics: bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "Enable OpenTelemetry metrics collection and Prometheus export"
        },
    )
    otel_metrics_host: str | None = Field(
        default="localhost",
        json_schema_extra={
            "title": "OpenTelemetry Metrics Host",
            "description": "Host to bind the OpenTelemetry metrics server to",
        },
    )
    # Bounded to the valid TCP port range so a bad value is rejected when the
    # config is validated rather than when the metrics server tries to bind.
    otel_metrics_port: int | None = Field(
        default=8000,
        ge=1,
        le=65535,
        json_schema_extra={
            "description": "Port for the Prometheus metrics HTTP server"
        },
    )
|
||||||
|
|||||||
@@ -546,7 +546,6 @@ class TestMultiGPULlama:
|
|||||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skip("regression failure from v4.57.0")
|
|
||||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import pytest
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.prompt_strategies.dpo.chat_template import default
|
from axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.hf_offline_utils import enable_hf_offline
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
@@ -78,6 +78,36 @@ def fixture_custom_assistant_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="argilla_chat_dataset")
def fixture_argilla_chat_dataset():
    """Single-row dataset whose chosen/rejected columns each hold a full conversation."""

    def conversation(assistant_reply):
        # Both conversations share the same user turn and differ only in the reply.
        return [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": assistant_reply},
        ]

    rows = [
        {
            "chosen": conversation("goodbye"),
            "rejected": conversation("party on"),
        }
    ]
    return Dataset.from_list(rows)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="phi3_tokenizer")
|
@pytest.fixture(name="phi3_tokenizer")
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def fixture_phi3_tokenizer():
|
def fixture_phi3_tokenizer():
|
||||||
@@ -216,5 +246,51 @@ class TestAssistantDPOChatTemplateGemma:
|
|||||||
assert result["rejected"] == "party on<end_of_turn>"
|
assert result["rejected"] == "party on<end_of_turn>"
|
||||||
|
|
||||||
|
|
||||||
|
class TestArgillaChatDPOChatTemplate:
    """
    Checks the argilla_chat strategy, where the chosen/rejected columns carry
    whole conversations rather than a single response.
    """

    @staticmethod
    def _apply(chat_template, row, tokenizer):
        """Build the strategy for `chat_template` and run it on one dataset row."""
        cfg = DictDefault(
            {
                "chat_template": chat_template,
                "datasets": [
                    {
                        "type": "chat_template.argilla_chat",
                    }
                ],
            }
        )
        transform_fn, _ = argilla_chat(cfg)
        return transform_fn(row, tokenizer=tokenizer)

    def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):
        output = self._apply("llama3", argilla_chat_dataset[0], llama3_tokenizer)

        expected_prompt = "".join(
            [
                "<|begin_of_text|>",
                "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
            ]
        )
        assert output["prompt"] == expected_prompt
        assert output["chosen"] == "goodbye<|eot_id|>"
        assert output["rejected"] == "party on<|eot_id|>"

    def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):
        output = self._apply(
            "tokenizer_default", argilla_chat_dataset[0], phi3_tokenizer
        )

        expected_prompt = "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
        assert output["prompt"] == expected_prompt
        assert output["chosen"] == "goodbye<|end|>"
        assert output["rejected"] == "party on<|end|>"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -80,16 +80,26 @@ class TestModelsUtils:
|
|||||||
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
|
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
|
||||||
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
|
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
|
||||||
)
|
)
|
||||||
elif load_in_8bit and self.cfg.adapter is not None:
|
|
||||||
assert self.model_loader.model_kwargs["load_in_8bit"]
|
|
||||||
elif load_in_4bit and self.cfg.adapter is not None:
|
|
||||||
assert self.model_loader.model_kwargs["load_in_4bit"]
|
|
||||||
|
|
||||||
if (self.cfg.adapter == "qlora" and load_in_4bit) or (
|
if self.cfg.adapter == "qlora" and load_in_4bit:
|
||||||
self.cfg.adapter == "lora" and load_in_8bit
|
assert isinstance(
|
||||||
):
|
self.model_loader.model_kwargs.get("quantization_config"),
|
||||||
assert self.model_loader.model_kwargs.get(
|
BitsAndBytesConfig,
|
||||||
"quantization_config", BitsAndBytesConfig
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.model_loader.model_kwargs["quantization_config"]._load_in_4bit
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
if self.cfg.adapter == "lora" and load_in_8bit:
|
||||||
|
assert isinstance(
|
||||||
|
self.model_loader.model_kwargs.get("quantization_config"),
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.model_loader.model_kwargs["quantization_config"]._load_in_8bit
|
||||||
|
is True
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_message_property_mapping(self):
|
def test_message_property_mapping(self):
|
||||||
|
|||||||
349
tests/test_opentelemetry_callback.py
Normal file
349
tests/test_opentelemetry_callback.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
"""Tests for OpenTelemetry metrics callback functionality."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_otel_config():
    """Config with OpenTelemetry metrics switched on, for building the callback."""
    settings = {
        "use_otel_metrics": True,
        "otel_metrics_host": "localhost",
        # Use unique port for tests
        "otel_metrics_port": 8003,
    }
    return DictDefault(settings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_trainer_state():
    """TrainerState positioned at epoch 1.0 / global step 100."""
    from transformers import TrainerState

    trainer_state = TrainerState()
    # Apply the overrides in the same order as plain attribute assignment would.
    for field_name, field_value in (("epoch", 1.0), ("global_step", 100)):
        setattr(trainer_state, field_name, field_value)
    return trainer_state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_training_args(tmp_path):
    """Training arguments for callback testing, rooted in a per-test temp directory."""
    from transformers import TrainingArguments

    # pytest's tmp_path gives every test its own directory, so tests neither
    # share state through a fixed path nor depend on /tmp existing.
    return TrainingArguments(output_dir=str(tmp_path))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_trainer_control():
    """Default-constructed TrainerControl to pass to callback hooks."""
    from transformers.trainer_callback import TrainerControl

    trainer_control = TrainerControl()
    return trainer_control
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenTelemetryConfig:
    """Schema-level checks for the OpenTelemetry configuration model."""

    def test_config_schema_valid(self):
        """A fully specified config round-trips its values."""
        from axolotl.utils.schemas.integrations import OpenTelemetryConfig

        parsed = OpenTelemetryConfig(
            use_otel_metrics=True,
            otel_metrics_host="localhost",
            otel_metrics_port=8000,
        )

        assert parsed.use_otel_metrics is True
        assert parsed.otel_metrics_host == "localhost"
        assert parsed.otel_metrics_port == 8000

    def test_config_defaults(self):
        """Host and port fall back to their defaults when only the switch is given."""
        from axolotl.utils.schemas.integrations import OpenTelemetryConfig

        parsed = OpenTelemetryConfig(use_otel_metrics=True)

        assert parsed.use_otel_metrics is True
        # defaults
        assert parsed.otel_metrics_host == "localhost"
        assert parsed.otel_metrics_port == 8000

    def test_config_disabled_by_default(self):
        """With no arguments at all, metrics export is off."""
        from axolotl.utils.schemas.integrations import OpenTelemetryConfig

        assert OpenTelemetryConfig().use_otel_metrics is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenTelemetryCallback:
|
||||||
|
"""Test OpenTelemetry callback functionality."""
|
||||||
|
|
||||||
|
def test_callback_import(self):
|
||||||
|
"""Test that OpenTelemetry callback can be imported."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback
|
||||||
|
|
||||||
|
assert OpenTelemetryMetricsCallback is not None
|
||||||
|
|
||||||
|
def test_callback_graceful_fallback(self, mock_otel_config):
|
||||||
|
"""Test callback gracefully handles missing dependencies."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback
|
||||||
|
|
||||||
|
# This should not raise an exception even if dependencies are missing
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
|
||||||
|
# Callback should exist but may have metrics disabled
|
||||||
|
assert callback is not None
|
||||||
|
assert hasattr(callback, "metrics_enabled")
|
||||||
|
|
||||||
|
def test_callback_initialization_enabled(self, mock_otel_config):
|
||||||
|
"""Test callback initialization when OpenTelemetry is available."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
|
||||||
|
if OPENTELEMETRY_AVAILABLE:
|
||||||
|
assert callback.metrics_enabled is True
|
||||||
|
assert callback.cfg == mock_otel_config
|
||||||
|
assert callback.metrics_host == "localhost"
|
||||||
|
assert callback.metrics_port == 8003
|
||||||
|
else:
|
||||||
|
assert callback.metrics_enabled is False
|
||||||
|
|
||||||
|
def test_metrics_server_lifecycle(
|
||||||
|
self,
|
||||||
|
mock_otel_config,
|
||||||
|
mock_trainer_state,
|
||||||
|
mock_training_args,
|
||||||
|
mock_trainer_control,
|
||||||
|
):
|
||||||
|
"""Test metrics server starts and stops correctly."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not OPENTELEMETRY_AVAILABLE:
|
||||||
|
pytest.skip("OpenTelemetry dependencies not available")
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
|
||||||
|
# Start server
|
||||||
|
callback.on_train_begin(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
assert callback.server_started is True
|
||||||
|
|
||||||
|
# End training
|
||||||
|
callback.on_train_end(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_metrics_recording(
|
||||||
|
self,
|
||||||
|
mock_otel_config,
|
||||||
|
mock_trainer_state,
|
||||||
|
mock_training_args,
|
||||||
|
mock_trainer_control,
|
||||||
|
):
|
||||||
|
"""Test that metrics are recorded during training."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not OPENTELEMETRY_AVAILABLE:
|
||||||
|
pytest.skip("OpenTelemetry dependencies not available")
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
callback.on_train_begin(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test logging metrics
|
||||||
|
test_logs = {
|
||||||
|
"loss": 0.5,
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"grad_norm": 0.8,
|
||||||
|
}
|
||||||
|
|
||||||
|
# This should not raise an exception
|
||||||
|
callback.on_log(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control, logs=test_logs
|
||||||
|
)
|
||||||
|
assert callback.metrics_enabled is True
|
||||||
|
|
||||||
|
def test_evaluation_metrics(
|
||||||
|
self,
|
||||||
|
mock_otel_config,
|
||||||
|
mock_trainer_state,
|
||||||
|
mock_training_args,
|
||||||
|
mock_trainer_control,
|
||||||
|
):
|
||||||
|
"""Test evaluation metrics recording."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not OPENTELEMETRY_AVAILABLE:
|
||||||
|
pytest.skip("OpenTelemetry dependencies not available")
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
callback.on_train_begin(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test evaluation metrics
|
||||||
|
eval_logs = {
|
||||||
|
"eval_loss": 0.3,
|
||||||
|
"eval_accuracy": 0.95,
|
||||||
|
}
|
||||||
|
|
||||||
|
# This should not raise an exception
|
||||||
|
callback.on_evaluate(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control, eval_logs
|
||||||
|
)
|
||||||
|
assert callback.metrics_enabled is True
|
||||||
|
|
||||||
|
def test_thread_safety(self, mock_otel_config):
    """Check that the callback exposes a ``metrics_lock`` context manager.

    The test is skipped when the OpenTelemetry dependencies are missing.
    """
    from axolotl.utils.callbacks.opentelemetry import (
        OPENTELEMETRY_AVAILABLE,
        OpenTelemetryMetricsCallback,
    )

    if not OPENTELEMETRY_AVAILABLE:
        pytest.skip("OpenTelemetry dependencies not available")

    otel_callback = OpenTelemetryMetricsCallback(mock_otel_config)
    assert hasattr(otel_callback, "metrics_lock")

    # A lock-like object has to provide both halves of the
    # context-manager protocol.
    lock = otel_callback.metrics_lock
    for dunder in ("__enter__", "__exit__"):
        assert hasattr(lock, dunder)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenTelemetryIntegration:
    """Integration-level checks for the OpenTelemetry support."""

    def test_availability_check(self):
        """The availability helper must return a plain ``bool``."""
        from axolotl.utils import is_opentelemetry_available

        assert isinstance(is_opentelemetry_available(), bool)

    def test_prometheus_endpoint_basic(
        self,
        mock_otel_config,
        mock_trainer_state,
        mock_training_args,
        mock_trainer_control,
    ):
        """Query the ``/metrics`` endpoint exposed by the callback.

        The test is skipped when OpenTelemetry or ``requests`` is not
        installed, when the metrics server did not start, or when the
        endpoint cannot be reached.
        """
        from axolotl.utils.callbacks.opentelemetry import (
            OPENTELEMETRY_AVAILABLE,
            OpenTelemetryMetricsCallback,
        )

        if not OPENTELEMETRY_AVAILABLE:
            pytest.skip("OpenTelemetry dependencies not available")

        try:
            import requests
        except ImportError:
            pytest.skip("requests library not available")

        otel_callback = OpenTelemetryMetricsCallback(mock_otel_config)
        hook_args = (mock_training_args, mock_trainer_state, mock_trainer_control)
        otel_callback.on_train_begin(*hook_args)

        if not otel_callback.server_started:
            pytest.skip("Metrics server failed to start")

        # Allow the server a moment to come up before querying it
        time.sleep(1)

        metrics_url = (
            f"http://{otel_callback.metrics_host}:{otel_callback.metrics_port}/metrics"
        )

        # Connection problems lead to a skip rather than a failure.
        try:
            response = requests.get(metrics_url, timeout=2)
            assert response.status_code == 200
            # The payload should contain Prometheus exposition markers
            payload = response.text
            assert "# TYPE" in payload or "# HELP" in payload
        except requests.exceptions.RequestException:
            pytest.skip(
                "Could not connect to metrics endpoint - this is expected in some environments"
            )
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenTelemetryCallbackMethods:
    """Checks for individual callback hook methods."""

    def test_step_end_callback(
        self,
        mock_otel_config,
        mock_trainer_state,
        mock_training_args,
        mock_trainer_control,
    ):
        """Call ``on_step_end`` after ``on_train_begin``; it must not raise.

        The test is skipped when the OpenTelemetry dependencies are missing.
        """
        from axolotl.utils.callbacks.opentelemetry import (
            OPENTELEMETRY_AVAILABLE,
            OpenTelemetryMetricsCallback,
        )

        if not OPENTELEMETRY_AVAILABLE:
            pytest.skip("OpenTelemetry dependencies not available")

        # The same (args, state, control) triple is passed to both hooks.
        hook_args = (mock_training_args, mock_trainer_state, mock_trainer_control)

        otel_callback = OpenTelemetryMetricsCallback(mock_otel_config)
        otel_callback.on_train_begin(*hook_args)

        # The hook is expected to complete without raising
        otel_callback.on_step_end(*hook_args)

    def test_epoch_end_callback(
        self,
        mock_otel_config,
        mock_trainer_state,
        mock_training_args,
        mock_trainer_control,
    ):
        """Call ``on_epoch_end`` after ``on_train_begin``; it must not raise.

        The test is skipped when the OpenTelemetry dependencies are missing.
        """
        from axolotl.utils.callbacks.opentelemetry import (
            OPENTELEMETRY_AVAILABLE,
            OpenTelemetryMetricsCallback,
        )

        if not OPENTELEMETRY_AVAILABLE:
            pytest.skip("OpenTelemetry dependencies not available")

        # The same (args, state, control) triple is passed to both hooks.
        hook_args = (mock_training_args, mock_trainer_state, mock_trainer_control)

        otel_callback = OpenTelemetryMetricsCallback(mock_otel_config)
        otel_callback.on_train_begin(*hook_args)

        # The hook is expected to complete without raising
        otel_callback.on_epoch_end(*hook_args)
|
||||||
Reference in New Issue
Block a user