Compare commits

..

8 Commits

Author SHA1 Message Date
Dan Saunders
c3e1882de5 progress 2025-08-22 02:43:16 -04:00
Dan Saunders
889b27ecf1 tui 2025-08-22 05:08:02 +00:00
Wing Lian
0fa752e58b upgrade flash-attn to 2.8.3 for gpt-oss attn sink support (#3082) 2025-08-21 15:04:10 -04:00
Dan Saunders
08e517ea48 Update .coderabbit.yaml (#3091) [skip ci] 2025-08-20 22:14:13 -04:00
Wing Lian
07fd22f39b better handling of lora w bias with fsdp2 and handling of files when saving model checkpoint (#3090) 2025-08-20 15:17:48 -04:00
Wing Lian
06eaf6c448 misc fixes (#3085) 2025-08-20 08:52:26 -04:00
goggle
050210e637 fix: Sweep runs overwrite each other because output_dir from base config is reused (#3080)
* refactor: improve output_dir handling in generate_config_files

* fix typo

* cli: harden sweep output_dir handling with base fallback

- Ensure sweep permutations always resolve a valid output_dir
- Default to ./model-out if neither permutation nor base config sets output_dir
- Append sweepXXXX suffix consistently for each permutation
- Prevent Path(None) TypeError and improve robustness of sweep config generation

* fix typo

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-19 20:25:20 -04:00
Wing Lian
05cedbfb1e add baseten info for gpt-oss recipe (#3078)
* add bsaeten info for gpt-oss recipe

* incorporate PR review
2025-08-19 13:30:37 -04:00
50 changed files with 13232 additions and 11816 deletions

View File

@@ -12,5 +12,6 @@ reviews:
auto_review: auto_review:
enabled: true enabled: true
drafts: false drafts: false
auto_incremental_review: true
chat: chat:
auto_reply: true auto_reply: true

File diff suppressed because it is too large Load Diff

View File

@@ -41,6 +41,12 @@ model, and final model output, you may need at least 3TB of free disk space to k
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
``` ```
To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
training of the 120B model using Baseten Truss. You can read more about this recipe on
[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
be found on their
[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`. ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue. See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
@@ -61,9 +67,23 @@ mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
### Inferencing your fine-tuned model ### Inferencing your fine-tuned model
#### vLLM
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425 GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
for more information about using a special vllm-openai docker image for inferencing with vLLM. for more information about using a special vllm-openai docker image for inferencing with vLLM.
Optionally, vLLM can be installed from nightly:
```bash
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash
vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
```
#### SGLang
SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server: SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:

View File

@@ -44,7 +44,7 @@ bf16: true
tf32: true tf32: true
flash_attention: true flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true gradient_checkpointing: true
activation_offloading: true activation_offloading: true

View File

@@ -40,7 +40,7 @@ bf16: true
tf32: true tf32: true
flash_attention: true flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true gradient_checkpointing: true
activation_offloading: true activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking field_thinking: thinking
template_thinking_key: thinking template_thinking_key: thinking
dataset_prepared_path: last_run_prepared dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0 val_set_size: 0
output_dir: ./outputs/gpt-oss-out/ output_dir: ./outputs/gpt-oss-out/
@@ -41,7 +41,7 @@ bf16: true
tf32: true tf32: true
flash_attention: true flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true gradient_checkpointing: true
activation_offloading: true activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking field_thinking: thinking
template_thinking_key: thinking template_thinking_key: thinking
dataset_prepared_path: last_run_prepared dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0 val_set_size: 0
output_dir: ./outputs/gpt-oss-out/ output_dir: ./outputs/gpt-oss-out/
@@ -40,7 +40,7 @@ bf16: true
tf32: true tf32: true
flash_attention: true flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true gradient_checkpointing: true
activation_offloading: true activation_offloading: true

View File

@@ -53,7 +53,7 @@ bf16: true
tf32: true tf32: true
flash_attention: true flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true gradient_checkpointing: true
activation_offloading: true activation_offloading: true

View File

@@ -1,57 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
pretraining_dataset:
- path: wikitext
name: wikitext-103-raw-v1
type: completion
field: text
plugins:
- diffusion.DiffusionPlugin
noise_schedule: cosine
min_mask_ratio: 0.15
max_mask_ratio: 0.85
eps: 5e-4
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 10
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
gradient_accumulation_steps: 8
micro_batch_size: 4
max_steps: 10000
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 3e-4
bf16: auto
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true
warmup_steps: 1000
save_strategy: steps
save_steps: 1000
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,58 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
val_set_size: 0.05
plugins:
- diffusion.DiffusionPlugin
noise_schedule: cosine
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
mask_token_id: 128002
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
eval_sample_packing: true
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 1e-5
bf16: auto
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true
warmup_steps: 1000
save_strategy: steps
eval_strategy: steps
save_steps: 500
eval_steps: 500
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -72,3 +72,8 @@ axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5 axolotl-contribs-mit==0.0.5
mistral-common==1.8.3 mistral-common==1.8.3
# TUI dependencies
textual==1.0.0
rich==14.1.0
tree_sitter_ruby==0.23.1

View File

@@ -118,9 +118,9 @@ def get_package_version():
extras_require = { extras_require = {
"flash-attn": ["flash-attn==2.8.2"], "flash-attn": ["flash-attn==2.8.3"],
"ring-flash-attn": [ "ring-flash-attn": [
"flash-attn==2.8.2", "flash-attn==2.8.3",
"ring-flash-attn>=0.1.7", "ring-flash-attn>=0.1.7",
"yunchang==0.6.0", "yunchang==0.6.0",
], ],

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res return res
def get_image(self): def get_image(self):
docker_tag = "main-py3.11-cu124-2.6.0" docker_tag = "main-py3.11-cu126-2.7.1"
if self.config.docker_tag: if self.config.docker_tag:
docker_tag = self.config.docker_tag docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}" docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -200,7 +200,7 @@ class ModalCloud(Cloud):
if family in ["a10", "a10g"]: if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count) return modal.gpu.A10G(count=count)
if family == "h100": if family == "h100":
return modal.gpu.H100(count=count) return f"H100:{count}"
if family == "t4": if family == "t4":
return modal.gpu.T4(count=count) return modal.gpu.T4(count=count)
if family == "l4": if family == "l4":

View File

@@ -64,7 +64,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter importlib.import_module("axolotl.prompters"), prompter
) )
elif cfg.chat_template: elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template) chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
elif cfg.datasets[0].type == "chat_template": elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config( chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer

View File

@@ -344,6 +344,26 @@ def delinearize_llama4(model: str, output: str):
cli.add_command(lm_eval) cli.add_command(lm_eval)
@cli.command()
def tui():
"""
Launch the Axolotl Terminal User Interface (TUI).
Provides an interactive interface for configuration management,
training monitoring, dataset handling, and model operations.
"""
try:
from axolotl.tui.app import run
run()
except ImportError:
click.echo(
"TUI dependencies not installed. Install with: pip install textual rich"
)
except Exception as e:
click.echo(f"Error launching TUI: {e}")
def main(): def main():
cli() cli()

View File

@@ -97,7 +97,8 @@ def do_cli(
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1" os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
parsed_cfg = load_cfg(config, **kwargs) is_preprocess = kwargs.pop("is_preprocess", True)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
parsed_cfg.is_preprocess = True parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs) parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -3,11 +3,12 @@
import random import random
from copy import deepcopy from copy import deepcopy
from itertools import product from itertools import product
from typing import Any
def generate_sweep_configs( def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list] base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, list]]: ) -> list[dict[str, Any]]:
""" """
Recursively generates all possible configurations by applying sweeps to the base config. Recursively generates all possible configurations by applying sweeps to the base config.

View File

@@ -4,6 +4,7 @@ import os
import subprocess # nosec import subprocess # nosec
import sys import sys
import tempfile import tempfile
from pathlib import Path
from typing import Any, Iterator, Literal from typing import Any, Iterator, Literal
import yaml import yaml
@@ -88,7 +89,12 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
# Generate all possible configurations # Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config) permutations = generate_sweep_configs(base_config, sweep_config)
is_group = len(permutations) > 1 is_group = len(permutations) > 1
for permutation in permutations: base_output_dir = base_config.get("output_dir", "./model-out")
for idx, permutation in enumerate(permutations, start=1):
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
permutation_id = f"sweep{idx:04d}"
permutation["output_dir"] = str(permutation_dir / permutation_id)
# pylint: disable=consider-using-with # pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile( temp_file = tempfile.NamedTemporaryFile(
mode="w", mode="w",

View File

@@ -10,7 +10,6 @@ import transformers
from transformers import ( from transformers import (
DataCollatorWithFlattening, DataCollatorWithFlattening,
EarlyStoppingCallback, EarlyStoppingCallback,
Trainer,
) )
from trl.trainer.utils import RewardDataCollatorWithPadding from trl.trainer.utils import RewardDataCollatorWithPadding
@@ -386,11 +385,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**data_collator_kwargs, **data_collator_kwargs,
) )
sig = inspect.signature(trainer_cls) sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer): if "processing_class" in sig.parameters:
trainer_kwargs["processing_class"] = self.tokenizer trainer_kwargs["processing_class"] = self.tokenizer
elif "tokenizer" in sig.parameters: elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer trainer_kwargs["tokenizer"] = self.tokenizer
if ( if (
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer] trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None and self.cfg.datasets is not None

View File

@@ -82,9 +82,7 @@ class AxolotlTrainer(
super().__init__(*_args, **kwargs) super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict( self._stored_metrics = defaultdict(lambda: defaultdict(list))
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
)
if self.args.orpo_alpha: if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -575,26 +573,9 @@ class AxolotlTrainer(
""" """
# logs either has 'loss' or 'eval_loss' # logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval" train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
# Add reduced stored metrics to logs for key, metrics in self._stored_metrics[train_eval].items():
for key, metric_data in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item()
values = torch.tensor(metric_data["values"])
reduction_type = metric_data["reduction"]
if reduction_type == "mean":
logs[key] = values.mean().item()
elif reduction_type == "min":
logs[key] = values.min().item()
elif reduction_type == "max":
logs[key] = values.max().item()
elif reduction_type == "sum":
logs[key] = values.sum().item()
else:
raise NotImplementedError(
"Metric reduction must be one of [mean, min, max, sum]"
)
logs[key] = round(logs[key], 4)
if is_main_process(): if is_main_process():
# Add memory usage # Add memory usage
@@ -611,27 +592,10 @@ class AxolotlTrainer(
return super().log(logs, start_time) return super().log(logs, start_time)
def store_metrics( def store_metrics(
self, self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
metrics: dict[str, float] | dict[str, tuple[int | float, str]],
train_eval: Literal["train", "eval"] = "train",
reduction: Literal["mean", "min", "max", "sum"] = "mean",
) -> None: ) -> None:
"""
Store metrics with specified reduction type.
Args:
metrics: Dictionary of metric names to values, or metric names to (value,
reduction_type) tuples.
train_eval: Whether this is for training or evaluation.
"""
for key, value in metrics.items(): for key, value in metrics.items():
if isinstance(value, tuple): self._stored_metrics[train_eval][key].append(value)
metric_value, metric_reduction = value
else:
metric_value, metric_reduction = value, reduction
self._stored_metrics[train_eval][key]["values"].append(metric_value)
self._stored_metrics[train_eval][key]["reduction"] = metric_reduction
def _save_checkpoint(self, model, trial, **kwargs): def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey # make sure the checkpoint dir exists, since trainer is flakey

View File

@@ -147,7 +147,7 @@ class BasePlugin:
""" """
# pylint: disable=unused-argument # pylint: disable=unused-argument
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None: def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Returns a custom class for the trainer. """Returns a custom class for the trainer.
Args: Args:

View File

@@ -1,164 +0,0 @@
# Diffusion LM Training Plugin for Axolotl
This plugin enables diffusion language model training using the LLaDA (Large Language
And Diffusion Assistant) approach within the Axolotl framework.
## Overview
LLaDA is a diffusion-based approach to language model training that uses:
- **Random token masking** during training instead of next-token prediction
- **Bidirectional attention** to allow the model to see the full context
- **Importance weighting** based on masking probabilities for stable training
This approach can lead to more robust language models with better understanding of
bidirectional context.
## Installation
The plugin is included with Axolotl. To use it, simply add the plugin configuration to
your training config.
## Quickstart
### Basic Configuration
Add the following to your Axolotl configuration YAML:
```yaml
# Enable diffusion LM training plugin
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
# Diffusion-specific configuration
noise_schedule: linear # or "cosine"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
mask_token_id: 128002
# Sample generation (optional)
generate_samples: true
generation_interval: 100
num_generation_samples: 3
generation_steps: 128
generation_temperature: 0.0
generation_max_length: 100
# Model configuration
base_model: meta-llama/Llama-3.2-1B
model_type: llama
# Standard Axolotl configuration
datasets:
- path: your_dataset
...
# Other config
sequence_len: 1024
micro_batch_size: 8
gradient_accumulation_steps: 4
learning_rate: 3e-4
```
## Supported Models
Currently supported base model types:
- **Llama** (meta-llama/Llama-*, etc.) - Uses `LlamaForDiffusionLM`
- **Mistral** (mistralai/Mistral-*, etc.) - Uses `MistralForDiffusionLM`
The plugin automatically creates custom model classes that inherit from the base model
while adding diffusion training capabilities. This provides full compatibility with
HuggingFace's ecosystem for saving, loading, and inference.
## How It Works
### Custom Model Architecture
The plugin creates custom model classes (`LlamaForDiffusionLM`, `MistralForDiffusionLM`) that inherit from
standard HuggingFace models. During training, these models:
1. **Apply forward diffusion process**: Randomly mask tokens based on sampled timesteps
2. **Use bidirectional attention**: Override causal attention with full bidirectional attention
3. **Compute diffusion loss**: Calculate loss only on masked tokens with optional importance weighting
### Random Masking
During training, tokens are randomly masked based on a sampled timestep:
- Sample timestep `t` uniformly from [0, 1]
- Calculate masking probability: `p = (1 - eps) * t + eps`
- Randomly mask tokens with probability `p`
### Bidirectional Attention
The models override causal attention with bidirectional attention:
- Creates 4D attention masks allowing all-to-all attention
- Maintains proper padding and sample packing masks
- Compatible with standard HuggingFace attention implementations
### Diffusion Loss
Loss is computed only on masked tokens with (optional) importance weighting:
```python
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
```
### Model Loading and Saving
The custom models work seamlessly with HuggingFace's AutoModel system:
```python
from transformers import AutoModel, AutoConfig
# Load a diffusion model
model = AutoModel.from_pretrained("path/to/diffusion/model", trust_remote_code=True)
# Save a diffusion model
model.save_pretrained("path/to/save/diffusion/model")
```
During inference, the models behave like standard causal language models.
## Sample Generation
When `generate_samples: true`, the plugin generates samples during training:
```
Sample 1:
Original (45 tokens): The quick brown fox jumps over the lazy dog...
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
Generated: The quick brown fox jumps over the lazy dog...
```
Samples are logged to console and wandb (if enabled).
## Metrics and Monitoring
The plugin adds several metrics to track diffusion training:
- `train/loss`: Weighted diffusion loss
- `train/accuracy`: Accuracy on masked tokens
- `train/mask_ratio`: Average fraction of tokens masked
- `train/num_masked_tokens`: Number of tokens masked
- `train/avg_p_mask`: Average masking probability
- `train/ce_loss`: Unweighted cross-entropy loss
- `train/importance_weight_avg`: Average importance weight
## Benefits of Custom Model Approach
**Type Safety**: Full IDE support and type checking
**HuggingFace Integration**: Works with AutoModel, Hub, pipelines
**Maintainability**: Clean architecture, no monkey patching
**Ecosystem Compatibility**: Standard save/load, PEFT support
**Testing**: Easier to test and debug
## Limitations
- **Model Support**: Currently limited to Llama and Mistral architectures
- **Flash Attention**: Not yet optimized for flash attention
- **Inference Speed**: Bidirectional attention is slower than causal for generation
## References
- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
- [Axolotl Documentation](https://docs.axolotl.ai/)

View File

@@ -1,26 +0,0 @@
"""Diffusion LM training plugin init."""
from transformers import AutoConfig, AutoModel
from .args import DiffusionArgs
from .configuration import DiffusionConfig, LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM
from .plugin import DiffusionPlugin
# Register custom configurations
AutoConfig.register("llama_diffusion", LlamaForDiffusionConfig)
AutoConfig.register("mistral_diffusion", MistralForDiffusionConfig)
# Register custom models
AutoModel.register(LlamaForDiffusionConfig, LlamaForDiffusionLM)
AutoModel.register(MistralForDiffusionConfig, MistralForDiffusionLM)
__all__ = [
"DiffusionArgs",
"DiffusionPlugin",
"DiffusionConfig",
"LlamaForDiffusionConfig",
"MistralForDiffusionConfig",
"LlamaForDiffusionLM",
"MistralForDiffusionLM",
]

View File

@@ -1,70 +0,0 @@
"""Config args for diffusion LM training."""
from typing import Literal
from pydantic import BaseModel, Field
class DiffusionArgs(BaseModel):
"""Arguments for diffusion LM training plugin."""
# Noise schedule config
noise_schedule: Literal["linear", "cosine"] = Field(
default="linear", description="Type of noise schedule for diffusion training"
)
min_mask_ratio: float = Field(
default=0.1,
ge=0.0,
le=1.0,
description="Minimum masking ratio for diffusion noise schedule",
)
max_mask_ratio: float = Field(
default=0.9,
ge=0.0,
le=1.0,
description="Maximum masking ratio for diffusion noise schedule",
)
num_diffusion_steps: int = Field(
default=128, ge=1, description="Number of diffusion timesteps"
)
eps: float = Field(
default=1e-3,
ge=0.0,
le=1.0,
description="Epsilon value for minimum masking probability in forward process",
)
# Training config
importance_weighting: bool = Field(
default=True,
description="Apply importance weighting to loss based on masking probability",
)
mask_token_id: int = Field(
default=128002,
description=(
"Token ID to use for masking. Default is 128002 "
"(<|reserved_special_token_0|> for Llama 3.2)"
),
)
# Sample generation config
generate_samples: bool = Field(
default=True, description="Enable sample generation during training"
)
generation_interval: int = Field(
default=100, ge=1, description="Generate samples every N steps"
)
num_generation_samples: int = Field(
default=3, ge=1, description="Number of samples to generate each time"
)
generation_steps: int = Field(
default=128, ge=1, description="Number of diffusion steps for generation"
)
generation_temperature: float = Field(
default=0.0,
ge=0.0,
description="Temperature for generation sampling (0.0 = deterministic)",
)
generation_max_length: int = Field(
default=100, ge=1, description="Maximum sequence length for generation"
)

View File

@@ -1,116 +0,0 @@
"""Callbacks for diffusion training."""
import wandb
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.utils.logging import get_logger
from .generation import generate_samples
LOG = get_logger(__name__)
class DiffusionGenerationCallback(TrainerCallback):
"""Callback for generating samples during diffusion training."""
def __init__(self, trainer):
self.trainer = trainer
# pylint: disable=unused-argument
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Generate samples at specified intervals."""
config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
if (
state.global_step > 0
and state.global_step % config.get('generation_interval', 100) == 0
):
# Use eval dataloader if available, otherwise use train dataloader
if (
hasattr(self.trainer, "eval_dataset")
and self.trainer.eval_dataset is not None
):
dataloader = self.trainer.get_eval_dataloader()
else:
dataloader = self.trainer.get_train_dataloader()
# Generate samples
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.tokenizer,
dataloader=dataloader,
num_generation_samples=config.get('num_generation_samples', 3),
max_length=config.get('generation_max_length', 256),
num_diffusion_steps=config.get('generation_steps', 10),
temperature=config.get('generation_temperature', 1.0),
mask_token_id=config.get('mask_token_id', 32000),
)
# Log samples
self._log_samples(samples, state.global_step)
def _log_samples(self, samples: list, step: int):
"""Log generated samples."""
if not samples:
return
LOG.info("=" * 60)
LOG.info("GENERATED SAMPLES")
LOG.info("=" * 60)
for i, sample_data in enumerate(samples, 1):
original = sample_data["original"]
masked = sample_data["masked"]
generated = sample_data["generated"]
mask_ratio = sample_data["mask_ratio"]
masked_tokens = sample_data["masked_tokens"]
total_tokens = sample_data["total_tokens"]
LOG.info(f"\nSample {i}:")
LOG.info(f"\tOriginal ({total_tokens} tokens): {original}")
LOG.info(
f"\tMasked ({masked_tokens}/{total_tokens} tokens, "
f"{mask_ratio:.1%}): {masked}"
)
LOG.info(f"\tGenerated: {generated}")
LOG.info("=" * 60)
config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
if config.get('use_wandb', False) and self.trainer.state.is_world_process_zero:
if wandb.run is not None:
wandb.log(
{
"generated_samples": wandb.Table(
columns=[
"step",
"original",
"masked",
"generated",
"mask_ratio",
"masked_tokens",
"total_tokens",
],
data=[
[
step,
sample["original"],
sample["masked"],
sample["generated"],
f"{sample['mask_ratio']:.1%}",
sample["masked_tokens"],
sample["total_tokens"],
]
for sample in samples
],
)
},
step=step,
)

View File

@@ -1,71 +0,0 @@
"""Configuration classes for diffusion language models."""
from transformers import LlamaConfig, MistralConfig
class LlamaForDiffusionConfig(LlamaConfig):
"""Configuration class for Llama models with diffusion training."""
model_type = "llama_diffusion"
def __init__(
self,
mask_token_id: int = 32000,
eps: float = 1e-3,
importance_weighting: bool = False,
sample_packing: bool = False,
min_mask_ratio: float = 0.0,
max_mask_ratio: float = 1.0,
noise_schedule: str = "linear",
**kwargs,
):
super().__init__(**kwargs)
# Diffusion-specific parameters
self.mask_token_id = mask_token_id
self.eps = eps
self.importance_weighting = importance_weighting
self.sample_packing = sample_packing
self.min_mask_ratio = min_mask_ratio
self.max_mask_ratio = max_mask_ratio
self.noise_schedule = noise_schedule
class MistralForDiffusionConfig(MistralConfig):
"""Configuration class for Mistral models with diffusion training."""
model_type = "mistral_diffusion"
def __init__(
self,
mask_token_id: int = 32000,
eps: float = 1e-3,
importance_weighting: bool = False,
sample_packing: bool = False,
min_mask_ratio: float = 0.0,
max_mask_ratio: float = 1.0,
noise_schedule: str = "linear",
**kwargs,
):
super().__init__(**kwargs)
# Diffusion-specific parameters
self.mask_token_id = mask_token_id
self.eps = eps
self.importance_weighting = importance_weighting
self.sample_packing = sample_packing
self.min_mask_ratio = min_mask_ratio
self.max_mask_ratio = max_mask_ratio
self.noise_schedule = noise_schedule
# Keep the base class for backward compatibility but mark as deprecated
class DiffusionConfig(LlamaForDiffusionConfig):
"""
Deprecated: Use LlamaForDiffusionConfig or MistralForDiffusionConfig instead.
"""
model_type = "diffusion"
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,269 +0,0 @@
"""Sample generation utilities for diffusion training."""
import logging
from typing import Any, List, Optional
import torch
logger = logging.getLogger(__name__)
def generate_samples(
model: torch.nn.Module,
tokenizer: Any,
dataloader: Optional[Any] = None,
num_generation_samples: int = 3,
max_length: int = 100,
num_diffusion_steps: int = 128,
temperature: float = 0.0,
mask_token_id: int = 32000,
) -> List[dict]:
"""
Generate text samples using the diffusion model by randomly masking sequences from
the given dataset and running the reverse diffusion process.
Args:
model: The wrapped or unwrapped model
tokenizer: Tokenizer for encoding/decoding
dataloader: Validation dataloader (for sampling sequences)
num_generation_samples: Number of samples to generate
max_length: Maximum length of sequences to use
num_diffusion_steps: Number of diffusion steps for generation
temperature: Temperature for sampling (0.0 = deterministic)
mask_token_id: Token ID used for masking
Returns:
List of dictionaries with original text, masked text, and generated text
"""
if dataloader is None:
logger.warning("No validation dataloader provided, cannot generate samples")
return []
# Get the actual model (unwrap if needed)
unwrapped_model = model.module if hasattr(model, "module") else model
unwrapped_model.eval()
generations = []
# Sample sequences from validation dataset
sampled_sequences = _sample_sequences_from_dataloader(
dataloader, num_generation_samples, max_length, unwrapped_model.device
)
logger.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
# Generate samples using reverse diffusion process
with torch.no_grad():
for original_sequence in sampled_sequences:
generation_result = _generate(
unwrapped_model,
tokenizer,
original_sequence,
num_diffusion_steps,
temperature,
mask_token_id,
)
generations.append(generation_result)
unwrapped_model.train()
return generations
def _sample_sequences_from_dataloader(
dataloader: Any, num_samples: int, max_length: int, device: torch.device
) -> List[torch.Tensor]:
"""Sample sequences from validation dataloader."""
sampled_sequences = []
sample_count = 0
# Add randomness by skipping a random number of batches
skip_batches = torch.randint(0, 6, (1,)).item()
batch_count = 0
for batch in dataloader:
# Skip some batches for variety
if batch_count < skip_batches:
batch_count += 1
continue
if sample_count >= num_samples:
break
batch_count += 1
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask")
# Randomly sample from sequences in this batch
batch_indices = torch.randperm(input_ids.size(0)).tolist()
for i in batch_indices:
if sample_count >= num_samples:
break
# Get actual sequence length (non-padded)
if attention_mask is not None:
seq_len = attention_mask[i].sum().item()
else:
seq_len = input_ids.size(1)
# Limit sequence length to max_length
actual_length = min(seq_len, max_length)
if actual_length < 10: # Skip very short sequences
continue
# Extract the sequence
sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)
sampled_sequences.append(sequence)
sample_count += 1
return sampled_sequences
def _generate(
model: torch.nn.Module,
tokenizer: Any,
original_sequence: torch.Tensor,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
) -> dict:
"""Generate a single sample using reverse diffusion."""
# Get original text for comparison
original_text = tokenizer.decode(
original_sequence[0].cpu(), skip_special_tokens=True
)
# Apply custom masking with random ratio (10% to 70%)
total_tokens = original_sequence.size(1)
min_ratio, max_ratio = 0.1, 0.7
target_mask_ratio = torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio
target_masked_tokens = int(total_tokens * target_mask_ratio)
# Create random mask indices
mask_positions = torch.randperm(total_tokens)[:target_masked_tokens]
masked_indices = torch.zeros(
1, total_tokens, dtype=torch.bool, device=original_sequence.device
)
masked_indices[0, mask_positions] = True
# Create masked sequence
masked_sequence = original_sequence.clone()
masked_sequence[masked_indices] = mask_token_id
# Calculate actual mask ratio
masked_tokens = masked_indices.sum().item()
mask_ratio = masked_tokens / total_tokens
# Get masked text for comparison
masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)
# Clean up mask token representation
masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)
# Run reverse diffusion process
sequence = masked_sequence.clone()
for step in range(num_diffusion_steps):
sequence = _diffusion_step(
model, sequence, step, num_diffusion_steps, temperature, mask_token_id
)
# Get final generated text
generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)
return {
"original": original_text,
"masked": masked_text,
"generated": generated_text,
"mask_ratio": mask_ratio,
"masked_tokens": masked_tokens,
"total_tokens": total_tokens,
"formatted": (
f"Original: '{original_text}' → Masked: '{masked_text}' "
f"({mask_ratio:.1%}) → Generated: '{generated_text}'"
),
}
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
"""Clean up masked text for display."""
mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
cleaned = masked_text.replace(mask_token_repr, "[MASK]")
if hasattr(tokenizer, "special_tokens_map"):
for token_value in tokenizer.special_tokens_map.values():
if token_value and isinstance(token_value, str):
cleaned = cleaned.replace(token_value, "")
cleaned = " ".join(cleaned.split()).strip()
return cleaned
def _diffusion_step(
model: torch.nn.Module,
sequence: torch.Tensor,
step: int,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
) -> torch.Tensor:
"""Perform a single diffusion step with remasking."""
# Only process if there are masked tokens remaining
current_mask = sequence == mask_token_id
if not current_mask.any():
return sequence
# Create bidirectional attention mask for diffusion
batch_size, seq_len = sequence.shape
attention_mask = torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device
)
# Forward pass
outputs = model(input_ids=sequence, attention_mask=attention_mask)
logits = outputs.logits
# Only sample at currently masked positions
if current_mask.any():
masked_logits = logits[current_mask]
# Apply temperature scaling
if temperature > 0:
scaled_logits = masked_logits / temperature
else:
scaled_logits = masked_logits
# Suppress mask token in outputs
scaled_logits[:, mask_token_id] = -float("inf")
# Sample predictions
if temperature > 0:
# Add Gumbel noise for sampling
gumbel_noise = -torch.log(
-torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))
)
gumbel_logits = scaled_logits + gumbel_noise
predicted_tokens = torch.argmax(gumbel_logits, dim=-1)
else:
# Deterministic sampling when temperature is 0
predicted_tokens = torch.argmax(scaled_logits, dim=-1)
# Calculate probabilities for confidence scoring
probs = torch.softmax(scaled_logits, dim=-1)
predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]
# Determine how many tokens to unmask this step
remaining_masked = current_mask.sum().item()
if step == num_diffusion_steps - 1:
num_to_unmask = remaining_masked
else:
unmask_ratio = 1.0 / (num_diffusion_steps - step)
num_to_unmask = max(1, int(remaining_masked * unmask_ratio))
# Select highest confidence predictions to unmask
if num_to_unmask >= remaining_masked:
sequence[current_mask] = predicted_tokens
else:
_, top_indices = predicted_token_probs.topk(num_to_unmask)
mask_positions = torch.where(current_mask)[1]
positions_to_unmask = mask_positions[top_indices]
sequence[0, positions_to_unmask] = predicted_tokens[top_indices]
return sequence

View File

@@ -1,426 +0,0 @@
"""Custom model classes for diffusion language models."""
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, MistralForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
class DiffusionModelMixin:
"""Mixin class providing diffusion functionality to language models."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._special_token_ids = None
def _cache_special_token_ids(self, tokenizer=None):
"""Cache special token IDs to avoid repeated tokenizer access."""
if tokenizer is None:
self._special_token_ids = set()
return
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = self.config.mask_token_id
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
"""
Create bidirectional attention mask to override default causal masking. Handles
sample-packed sequences where different samples are identified by different
attention mask values.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not self.config.sample_packing:
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
)
# Create attention mask by comparing sample IDs element-wise
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
# Tokens can attend to each other if they have the same non-zero sample ID
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
# Add head dimension: [batch_size, 1, seq_len, seq_len]
bidirectional_mask = bidirectional_mask.unsqueeze(1)
return bidirectional_mask
def _compute_diffusion_loss(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
logits: torch.Tensor | None = None,
masked_indices: torch.Tensor | None = None,
p_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Compute diffusion loss given logits and masking information.
Args:
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
logits: Model logits [batch_size, seq_len, vocab_size].
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
"""
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)
batch_indices, seq_indices = valid_indices
masked_logits = logits[batch_indices, seq_indices]
masked_targets = input_ids[batch_indices, seq_indices]
masked_p_mask = p_mask[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.config.importance_weighting:
masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
# Final loss: sum weighted losses, normalize
if labels is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
loss_per_sample = torch.zeros(
input_ids.shape[0], device=input_ids.device
)
for i in range(input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.sum() > 0:
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / answer_lengths[i]
loss = loss_per_sample.mean()
else:
# Original normalization for non-SFT data
loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
return loss
class LlamaForDiffusionLM(DiffusionModelMixin, LlamaForCausalLM):
"""
Llama model for diffusion language modeling.
This model extends LlamaForCausalLM with diffusion training capabilities,
including bidirectional attention and forward diffusion process.
"""
config_class = LlamaForDiffusionConfig
def __init__(self, config):
super().__init__(config)
# Initialize diffusion-specific attributes
self._special_token_ids = None
# Initialize weights and apply final processing
self.post_init()
def set_tokenizer(self, tokenizer):
"""Set tokenizer for special token handling."""
self._cache_special_token_ids(tokenizer)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass with diffusion training logic.
During training, applies forward diffusion process and bidirectional attention.
During inference, behaves like standard causal language model.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training and input_ids is not None:
# Apply diffusion process during training
original_input_ids = input_ids.clone()
# Apply forward process to get noisy input
noisy_input_ids, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_attention_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass with noisy input and bidirectional attention
outputs = super().forward(
input_ids=noisy_input_ids,
attention_mask=bidirectional_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=None, # Don't use standard loss computation
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Compute diffusion loss
loss = self._compute_diffusion_loss(
original_input_ids,
attention_mask,
labels,
outputs.logits,
masked_indices,
p_mask,
)
if return_dict:
outputs.loss = loss
return outputs
else:
return (loss,) + outputs[1:]
else:
# Standard forward pass for inference
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
class MistralForDiffusionLM(DiffusionModelMixin, MistralForCausalLM):
"""
Mistral model for diffusion language modeling.
This model extends MistralForCausalLM with diffusion training capabilities,
including bidirectional attention and forward diffusion process.
"""
config_class = MistralForDiffusionConfig
def __init__(self, config):
super().__init__(config)
# Initialize diffusion-specific attributes
self._special_token_ids = None
# Initialize weights and apply final processing
self.post_init()
def set_tokenizer(self, tokenizer):
"""Set tokenizer for special token handling."""
self._cache_special_token_ids(tokenizer)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass with diffusion training logic.
During training, applies forward diffusion process and bidirectional attention.
During inference, behaves like standard causal language model.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training and input_ids is not None:
# Apply diffusion process during training
original_input_ids = input_ids.clone()
# Apply forward process to get noisy input
noisy_input_ids, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_attention_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass with noisy input and bidirectional attention
outputs = super().forward(
input_ids=noisy_input_ids,
attention_mask=bidirectional_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=None, # Don't use standard loss computation
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Compute diffusion loss
loss = self._compute_diffusion_loss(
original_input_ids,
attention_mask,
labels,
outputs.logits,
masked_indices,
p_mask,
)
if return_dict:
outputs.loss = loss
return outputs
else:
return (loss,) + outputs[1:]
else:
# Standard forward pass for inference
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

View File

@@ -1,98 +0,0 @@
"""Diffusion LM training plugin for Axolotl."""
from typing import TYPE_CHECKING
from peft import PeftModel
from transformers import AutoConfig, AutoModel, PreTrainedModel
from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM
if TYPE_CHECKING:
from transformers import Trainer
LOG = get_logger(__name__)
class DiffusionPlugin(BasePlugin):
"""
Plugin for diffusion language model training.
This plugin enables diffusion-based training using the LLaDA approach, which uses
random masking and bidirectional attention to train language models.
"""
def __init__(self):
super().__init__()
self.cfg = None
def get_input_args(self) -> str:
"""Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs"
def pre_model_load(self, cfg: DictDefault):
"""Configure model loading to use diffusion model classes."""
# Map base model types to diffusion equivalents
base_model_type = cfg.get("model_type")
if base_model_type == "llama":
# Create diffusion config from base config
diffusion_config = LlamaForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "llama_diffusion"
elif base_model_type == "mistral":
# Create diffusion config from base config
diffusion_config = MistralForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "mistral_diffusion"
else:
LOG.warning(f"Diffusion plugin not implemented for model type: {base_model_type}")
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Configure model after loading."""
self.cfg = cfg
# Set tokenizer on diffusion models for special token handling
if hasattr(model, "set_tokenizer"):
# Get tokenizer from cfg if available
tokenizer = getattr(cfg, "tokenizer", None)
if tokenizer is not None:
model.set_tokenizer(tokenizer)
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer: "Trainer"):
"""Add diffusion-specific callbacks after trainer creation."""
callbacks = []
# Store diffusion config on trainer for callbacks
trainer.diffusion_config = cfg
# Add generation callback if enabled
if cfg.get("generate_samples", False):
generation_callback = DiffusionGenerationCallback(trainer)
callbacks.append(generation_callback)
return callbacks

View File

@@ -681,23 +681,6 @@ class ModelLoader:
return hf_ds_cfg return hf_ds_cfg
def _load_model_from_config(self) -> PreTrainedModel:
"""Load model with random initialization using from_config."""
if self.auto_model_loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
return self.auto_model_loader.from_config(config=self.model_config)
return self.auto_model_loader(config=self.model_config)
def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:
"""Load model from pretrained weights."""
loader = model_loader_class or self.auto_model_loader
kwargs = {
**self.model_kwargs,
"config": self.model_config,
"trust_remote_code": self.cfg.trust_remote_code or False,
**self.model_kwargs,
}
return loader.from_pretrained(self.base_model, **kwargs)
def _build_model(self) -> bool: def _build_model(self) -> bool:
"""Load model, with load strategy depending on config.""" """Load model, with load strategy depending on config."""
skip_move_to_device = False skip_move_to_device = False
@@ -712,8 +695,7 @@ class ModelLoader:
if self.is_fsdp_enabled: if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading: if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True skip_move_to_device = True
# Don't delete device_map for QLoRA + FSDP - it was set correctly in # Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
# _set_device_map
if ( if (
"device_map" in self.model_kwargs "device_map" in self.model_kwargs
and not self.is_qlora_and_fsdp_enabled and not self.is_qlora_and_fsdp_enabled
@@ -742,11 +724,6 @@ class ModelLoader:
or self.cfg.qlora_sharded_model_loading or self.cfg.qlora_sharded_model_loading
) )
): ):
if self.cfg.reinit_weights:
LOG.warning(
"reinit_weights is not supported with sharded quantized loading. "
"Loading from pretrained weights instead."
)
quant_storage = self.cfg.torch_dtype quant_storage = self.cfg.torch_dtype
quantization_config = getattr( quantization_config = getattr(
self.model_config, "quantization_config", None self.model_config, "quantization_config", None
@@ -762,12 +739,33 @@ class ModelLoader:
quantization_config=quantization_config, quantization_config=quantization_config,
) )
skip_move_to_device = True skip_move_to_device = True
elif self.model_type == "MambaLMHeadModel": elif (
if self.cfg.reinit_weights: self.model_config.model_type in ["llama", "llama4"]
LOG.warning( and not self.cfg.trust_remote_code
"reinit_weights is not supported with MambaLMHeadModel. " and not self.cfg.gptq
"Loading from pretrained weights instead." ):
# Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading()
# Load model with random initialization if specified
if self.cfg.random_init_weights:
# AutoModel classes support the from_config method
if self.auto_model_loader in [
AutoModelForCausalLM,
AutoModelForVision2Seq,
]:
self.model = self.auto_model_loader.from_config(
config=self.model_config,
)
else:
self.model = self.auto_model_loader(config=self.model_config)
else:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
**self.model_kwargs,
) )
elif self.model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work # FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
@@ -780,27 +778,41 @@ class ModelLoader:
self.base_model, self.base_model,
**self.model_kwargs, **self.model_kwargs,
) )
elif (
self.model_type
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
):
if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
self.model = getattr(transformers, self.model_type).from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
elif self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else: else:
# Please don't remove underscore binding without reading the fn docstring # Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading() _ = self._configure_zero3_memory_efficient_loading()
self.model = self.auto_model_loader.from_pretrained(
if ( self.base_model,
self.model_type config=self.model_config,
and self.model_type != "AutoModelForCausalLM" trust_remote_code=self.cfg.trust_remote_code or False,
and not self.cfg.trust_remote_code **self.model_kwargs,
and not self.cfg.gptq )
):
# Use model type from transformers
model_loader_class = getattr(transformers, self.model_type)
else:
# Use auto model loader (handles gptq and default cases)
model_loader_class = self.auto_model_loader
if self.cfg.reinit_weights:
self.model = self._load_model_from_config()
else:
self.model = self._load_model_from_pretrained(model_loader_class)
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
skip_move_to_device = True skip_move_to_device = True

View File

@@ -187,7 +187,7 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight # wrap this. Therefore we must ensure the bias has the same dtype as the weight
if module.base_layer.bias is not None: if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype: if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to( module.base_layer.bias.data = module.base_layer.bias.data.to(

View File

@@ -72,9 +72,10 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
builder_kwargs["message_field_training"] = message_field_training builder_kwargs["message_field_training"] = message_field_training
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml")) chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
format_message = (
lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment def format_message(x):
) return x
if chat_template == "chatml": if chat_template == "chatml":
from axolotl.core.chat.format.chatml import format_message # noqa F811 from axolotl.core.chat.format.chatml import format_message # noqa F811
if chat_template.startswith("llama3"): if chat_template.startswith("llama3"):

View File

@@ -75,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC):
) -> BatchEncoding: ) -> BatchEncoding:
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []}) empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
if not prompt: if not prompt:
LOG.warning_once("Empty text requested for tokenization.") LOG.warning("Empty text requested for tokenization.")
return empty return empty
result = self.tokenizer( result = self.tokenizer(

View File

@@ -253,7 +253,9 @@ def save_trained_model(
# final model weights have already been saved by `ReLoRACallback.on_train_end` # final model weights have already been saved by `ReLoRACallback.on_train_end`
return return
if trainer.is_fsdp_enabled or cfg.fsdp_config: if ( # pylint: disable=too-many-nested-blocks
trainer.is_fsdp_enabled or cfg.fsdp_config
):
if cfg.fsdp_config or cfg.fsdp: if cfg.fsdp_config or cfg.fsdp:
if cfg.fsdp_config.final_state_dict_type: if cfg.fsdp_config.final_state_dict_type:
state_dict_type = cfg.fsdp_config.final_state_dict_type state_dict_type = cfg.fsdp_config.final_state_dict_type
@@ -285,6 +287,8 @@ def save_trained_model(
if trainer.accelerator.is_main_process: if trainer.accelerator.is_main_process:
# move all files in merged_path to cfg.output_dir # move all files in merged_path to cfg.output_dir
for merged_file in Path(merged_path).iterdir(): for merged_file in Path(merged_path).iterdir():
if (Path(cfg.output_dir) / merged_file.name).exists():
(Path(cfg.output_dir) / merged_file.name).unlink()
shutil.move(str(merged_file), cfg.output_dir) shutil.move(str(merged_file), cfg.output_dir)
shutil.rmtree(merged_path) # remove what should be an empty dir shutil.rmtree(merged_path) # remove what should be an empty dir
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207 # TODO(wing):see https://github.com/huggingface/transformers/pull/40207

216
src/axolotl/tui/README.md Normal file
View File

@@ -0,0 +1,216 @@
# Axolotl TUI (Terminal User Interface)
A comprehensive Terminal User Interface for Axolotl, providing an interactive way to manage configurations, training jobs, datasets, models, and system monitoring.
## Features
### 🏠 Main Dashboard
- **Welcome Screen**: Central hub with quick access to all features
- **Keyboard Navigation**: Efficient navigation with keyboard shortcuts
- **Screen Management**: Easy switching between different functional areas
### 📝 Configuration Management
- **YAML Editor**: Syntax-highlighted editor for Axolotl configurations
- **Real-time Validation**: Instant config validation with detailed error reporting
- **File Browser**: Navigate and select configuration files
- **Template Loading**: Load example configurations
- **Remote Config Support**: Load configurations from URLs
**Key Shortcuts:**
- `Ctrl+N`: New configuration
- `Ctrl+S`: Save configuration
- `Ctrl+V`: Validate configuration
- `Ctrl+E`: Toggle edit mode
### 🚀 Training Management
- **Job Launcher**: Start training with different launchers (accelerate, torchrun)
- **Real-time Monitoring**: Live training progress and metrics
- **Loss Visualization**: Sparkline charts for loss curves
- **Job Control**: Start, stop, resume, and manage multiple training jobs
- **Log Streaming**: Real-time log viewing and filtering
**Key Shortcuts:**
- `Ctrl+T`: New training job
- `Ctrl+R`: Resume training
- `Ctrl+X`: Stop training
- `R`: Refresh status
### 📊 Dataset Management
- **Dataset Browser**: Explore local and remote datasets
- **Preview & Statistics**: View dataset samples and metadata
- **Preprocessing**: Run dataset preprocessing with progress tracking
- **HuggingFace Integration**: Download and manage HF datasets
- **Format Detection**: Automatic dataset format recognition
**Key Shortcuts:**
- `Ctrl+P`: Preprocess dataset
- `Ctrl+V`: Preview dataset
- `Ctrl+I`: Dataset information
- `R`: Refresh dataset list
### 🤖 Model Management
- **Model Discovery**: Automatically find trained models
- **LoRA Operations**: Merge LoRA adapters with base models
- **Quantization**: Quantize models for deployment
- **Evaluation**: Run model evaluation benchmarks
- **Storage Info**: View model sizes and storage details
**Key Shortcuts:**
- `Ctrl+M`: Merge LoRA
- `Ctrl+Q`: Quantize model
- `Ctrl+E`: Evaluate model
- `R`: Refresh model list
### 💬 Inference & Testing
- **Interactive Chat**: Chat interface for model testing
- **Parameter Tuning**: Adjust inference parameters (temperature, top-p, max tokens)
- **Model Loading**: Load and switch between different models
- **Chat History**: Save and load conversation history
- **Gradio Integration**: Launch Gradio web interface
**Key Shortcuts:**
- `Ctrl+Enter`: Send message
- `Ctrl+C`: Clear chat
- `Ctrl+L`: Load model
- `Ctrl+S`: Save chat
### 📈 System Monitoring
- **Resource Monitoring**: Real-time CPU, GPU, and memory usage
- **Process Management**: View and manage running processes
- **Performance Graphs**: Historical usage charts with sparklines
- **GPU Information**: Detailed GPU status and memory usage
- **Temperature Monitoring**: System temperature tracking
**Key Shortcuts:**
- `R`: Refresh metrics
- `Ctrl+K`: Kill selected process
## Installation
### Dependencies
```bash
pip install textual==1.0.0 rich==14.1.0
```
### Launch TUI
```bash
# From command line
python -m axolotl.cli.main tui
# From Python code
from axolotl.tui.app import run
run()
```
## Architecture
### Screen Structure
```
AxolotlTUI (Main App)
├── WelcomeScreen (Dashboard)
├── ConfigScreen (Configuration Management)
├── TrainingScreen (Training Management)
├── DatasetScreen (Dataset Management)
├── ModelScreen (Model Management)
├── InferenceScreen (Inference & Testing)
└── MonitorScreen (System Monitoring)
```
### Key Components
- **BaseScreen**: Common functionality for all screens
- **Screen Navigation**: Stack-based screen management
- **Event Handling**: Reactive UI updates
- **Background Tasks**: Non-blocking operations
- **State Management**: Shared application state
### Integration Points
- **CLI Commands**: Seamless integration with existing axolotl CLI
- **Configuration System**: Uses axolotl's native config loading
- **Training Pipeline**: Integrates with axolotl training functions
- **Model Loading**: Compatible with axolotl model management
## Usage Examples
### 1. Creating a New Configuration
1. Launch TUI: `python -m axolotl.cli.main tui`
2. Select "Configuration Management" or press `C`
3. Press `Ctrl+N` for new configuration
4. Edit the template configuration
5. Press `Ctrl+V` to validate
6. Press `Ctrl+S` to save
### 2. Starting a Training Job
1. Navigate to "Training Management" or press `T`
2. Press `Ctrl+T` for new training job
3. Select configuration file and launcher
4. Monitor progress in real-time
5. View loss curves and logs
### 3. Interactive Model Testing
1. Go to "Inference & Testing" or press `I`
2. Load a trained model with `Ctrl+L`
3. Adjust inference parameters as needed
4. Start chatting with the model
5. Save conversation with `Ctrl+S`
## Navigation
### Global Shortcuts
- `Ctrl+Q`: Quit application
- `Escape`: Go back/close current screen
- `Tab`: Navigate between UI elements
- `Enter`: Select/activate element
- `Space`: Toggle switches/checkboxes
### Screen Shortcuts
Each screen has specific shortcuts displayed in the footer. Common patterns:
- `Ctrl+[Letter]`: Primary actions
- `R`: Refresh/reload
- `F1-F12`: Function keys for advanced features
## Customization
### Themes
The TUI uses Textual's theming system and can be customized by modifying the CSS in each screen class.
### Adding New Screens
1. Create a new screen class inheriting from `BaseScreen`
2. Implement the `compose()` method for UI layout
3. Add event handlers for user interactions
4. Register the screen in the main app navigation
### Extending Functionality
- Add new widgets to existing screens
- Implement custom data visualization
- Integrate with external tools and APIs
- Add new keyboard shortcuts
## Troubleshooting
### Common Issues
1. **Import Errors**: Ensure textual and rich are installed
2. **Permission Errors**: Check file system permissions for config directories
3. **GPU Monitoring**: Install pynvml for GPU monitoring features
4. **Config Validation**: Ensure axolotl dependencies are properly installed
### Debug Mode
Launch with debug logging:
```bash
TEXTUAL_LOG=DEBUG python -m axolotl.cli.main tui
```
### Performance
- Use `Ctrl+\` to open Textual's debug console
- Monitor memory usage with the system monitor
- Disable auto-refresh for better performance on slower systems
## Contributing
The TUI is designed to be extensible. Contributions are welcome for:
- New screen implementations
- Enhanced visualizations
- Better keyboard navigation
- Additional integrations
- Performance improvements
See the main Axolotl repository for contribution guidelines.

View File

@@ -0,0 +1 @@
"""Axolotl Terminal User Interface (TUI)."""

180
src/axolotl/tui/app.py Normal file
View File

@@ -0,0 +1,180 @@
"""Main TUI application for Axolotl."""
from textual import on
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import Screen
from textual.widgets import Button, Footer, Header, Static
from axolotl.tui.screens.config import ConfigScreen
from axolotl.tui.screens.datasets import DatasetScreen
from axolotl.tui.screens.inference import InferenceScreen
from axolotl.tui.screens.models import ModelScreen
from axolotl.tui.screens.monitor import MonitorScreen
from axolotl.tui.screens.training import TrainingScreen
class WelcomeScreen(Screen):
"""Welcome screen with main menu."""
BINDINGS = [
Binding("q", "quit", "Quit"),
Binding("c", "config", "Configuration"),
Binding("t", "training", "Training"),
Binding("d", "datasets", "Datasets"),
Binding("m", "models", "Models"),
Binding("i", "inference", "Inference"),
Binding("s", "monitor", "System Monitor"),
]
def compose(self) -> ComposeResult:
"""Compose the welcome screen."""
yield Header()
yield Container(
Static("🦾 Axolotl TUI", classes="title"),
Static(
"A Terminal User Interface for fine-tuning LLMs", classes="subtitle"
),
Container(
Button("Configuration Management [C]", id="config", variant="primary"),
Button("Training Management [T]", id="training", variant="primary"),
Button("Dataset Management [D]", id="datasets", variant="primary"),
Button("Model Management [M]", id="models", variant="primary"),
Button("Inference & Testing [I]", id="inference", variant="primary"),
Button("System Monitor [S]", id="monitor", variant="primary"),
classes="menu-container",
),
classes="welcome-container",
)
yield Footer()
def action_quit(self) -> None:
"""Quit the application."""
self.app.exit()
def action_config(self) -> None:
"""Navigate to config screen."""
self.app.push_screen(ConfigScreen())
def action_training(self) -> None:
"""Navigate to training screen."""
self.app.push_screen(TrainingScreen())
def action_datasets(self) -> None:
"""Navigate to datasets screen."""
self.app.push_screen(DatasetScreen())
def action_models(self) -> None:
"""Navigate to models screen."""
self.app.push_screen(ModelScreen())
def action_inference(self) -> None:
"""Navigate to inference screen."""
self.app.push_screen(InferenceScreen())
def action_monitor(self) -> None:
"""Navigate to monitor screen."""
self.app.push_screen(MonitorScreen())
@on(Button.Pressed, "#config")
def on_config_pressed(self) -> None:
"""Handle config button press."""
self.action_config()
@on(Button.Pressed, "#training")
def on_training_pressed(self) -> None:
"""Handle training button press."""
self.action_training()
@on(Button.Pressed, "#datasets")
def on_datasets_pressed(self) -> None:
"""Handle datasets button press."""
self.action_datasets()
@on(Button.Pressed, "#models")
def on_models_pressed(self) -> None:
"""Handle models button press."""
self.action_models()
@on(Button.Pressed, "#inference")
def on_inference_pressed(self) -> None:
"""Handle inference button press."""
self.action_inference()
@on(Button.Pressed, "#monitor")
def on_monitor_pressed(self) -> None:
"""Handle monitor button press."""
self.action_monitor()
class AxolotlTUI(App):
"""Main Axolotl TUI Application."""
CSS = """
.title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.subtitle {
text-align: center;
padding: 1;
color: $text-muted;
}
.welcome-container {
align: center middle;
height: 100%;
width: 100%;
}
.menu-container {
layout: vertical;
align: center middle;
padding: 2;
width: auto;
height: auto;
}
.menu-container Button {
width: 35;
margin: 1;
}
WelcomeScreen {
align: center middle;
}
"""
BINDINGS = [
Binding("ctrl+q", "quit", "Quit", priority=True),
Binding("escape", "back", "Back", priority=True),
]
def on_mount(self) -> None:
"""Called when the app is mounted."""
self.title = "Axolotl TUI"
self.sub_title = "Fine-tuning LLMs made easy"
self.push_screen(WelcomeScreen())
def action_quit(self) -> None:
"""Quit the application."""
self.exit()
def action_back(self) -> None:
"""Go back to previous screen."""
if len(self.screen_stack) > 1:
self.pop_screen()
def run():
"""Run the Axolotl TUI application."""
app = AxolotlTUI()
app.run()
if __name__ == "__main__":
run()

View File

@@ -0,0 +1 @@
"""TUI dialogs for Axolotl."""

View File

@@ -0,0 +1,112 @@
"""Training dialogs for Axolotl TUI."""
from pathlib import Path
from textual import on
from textual.app import ComposeResult
from textual.containers import Container
from textual.screen import ModalScreen
from textual.widgets import Button, Input, Label, Select, Static
class NewTrainingDialog(ModalScreen):
"""Dialog for starting a new training job."""
CSS = """
NewTrainingDialog {
align: center middle;
}
.dialog-container {
background: $surface;
border: thick $primary;
padding: 2;
width: 60;
height: auto;
}
.dialog-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.form-field {
margin: 1 0;
}
.form-label {
margin: 0 0 1 0;
color: $text-muted;
}
.button-container {
layout: horizontal;
align: center middle;
margin: 2 0 0 0;
}
.button-container Button {
margin: 0 1;
}
"""
def compose(self) -> ComposeResult:
"""Compose the dialog."""
yield Container(
Static("Start New Training Job", classes="dialog-title"),
Container(
Label("Configuration File:", classes="form-label"),
Input(
placeholder="Path to config YAML file",
id="config-path",
value="/workspace/configs/",
),
classes="form-field",
),
Container(
Label("Launcher:", classes="form-label"),
Select(
[
("accelerate", "Accelerate (Recommended)"),
("torchrun", "TorchRun"),
("deepspeed", "DeepSpeed"),
],
id="launcher",
value="accelerate",
),
classes="form-field",
),
Container(
Button("Start Training", variant="primary", id="start"),
Button("Cancel", variant="default", id="cancel"),
classes="button-container",
),
classes="dialog-container",
)
@on(Button.Pressed, "#start")
def handle_start(self) -> None:
"""Handle start button press."""
config_input = self.query_one("#config-path", Input)
launcher_select = self.query_one("#launcher", Select)
config_path = config_input.value.strip()
if not config_path:
return
if not Path(config_path).exists():
return
result = {
"config_path": config_path,
"launcher": launcher_select.value,
}
self.dismiss(result)
@on(Button.Pressed, "#cancel")
def handle_cancel(self) -> None:
"""Handle cancel button press."""
self.dismiss(None)

View File

@@ -0,0 +1 @@
"""TUI screens for Axolotl."""

View File

@@ -0,0 +1,50 @@
"""Base screen class for Axolotl TUI screens."""
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import Screen
from textual.widgets import Footer, Header, Static
class BaseScreen(Screen):
"""Base class for all Axolotl TUI screens."""
BINDINGS = [
Binding("escape", "back", "Back"),
Binding("q", "quit", "Quit"),
]
def __init__(self, title: str = "Axolotl", subtitle: str = ""):
"""Initialize the base screen.
Args:
title: The screen title
subtitle: Optional subtitle for the screen
"""
super().__init__()
self.screen_title = title
self.screen_subtitle = subtitle
def compose(self) -> ComposeResult:
"""Compose the base screen layout."""
yield Header()
yield Container(
Static(f"🦾 {self.screen_title}", classes="screen-title"),
(
Static(self.screen_subtitle, classes="screen-subtitle")
if self.screen_subtitle
else Static("")
),
Container(id="content"),
id="main-container",
)
yield Footer()
def action_back(self) -> None:
"""Go back to previous screen."""
self.app.pop_screen()
def action_quit(self) -> None:
"""Quit the application."""
self.app.exit()

View File

@@ -0,0 +1,376 @@
"""Configuration management screen for Axolotl TUI."""
import os
from pathlib import Path
from typing import Optional
import yaml
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.reactive import reactive
from textual.widgets import (
Button,
DirectoryTree,
Footer,
Header,
Label,
Log,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class ConfigScreen(BaseScreen):
"""Configuration management screen."""
BINDINGS = [
Binding("ctrl+n", "new_config", "New Config"),
Binding("ctrl+o", "open_config", "Open Config"),
Binding("ctrl+s", "save_config", "Save Config"),
Binding("ctrl+v", "validate_config", "Validate"),
Binding("ctrl+e", "edit_mode", "Toggle Edit Mode"),
]
CSS = """
.config-container {
layout: horizontal;
height: 100%;
}
.file-browser {
width: 30%;
border: solid $primary;
padding: 1;
margin: 1;
}
.config-editor {
width: 70%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.config-form {
height: 80%;
}
.config-actions {
layout: horizontal;
height: 3;
align: center middle;
padding: 1;
}
.config-actions Button {
margin: 0 1;
}
TextArea {
height: 100%;
}
.validation-log {
height: 20%;
border: solid $warning;
padding: 1;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
"""
def __init__(self):
"""Initialize the config screen."""
super().__init__(
title="Configuration Management",
subtitle="Create, edit, and validate Axolotl configurations",
)
self.current_config_path: Optional[Path] = None
self.edit_mode = reactive(False)
self.config_data = {}
def compose(self) -> ComposeResult:
"""Compose the config screen layout."""
yield Header()
yield Container(
Static("🦾 Configuration Management", classes="screen-title"),
Static(
"Create, edit, and validate Axolotl configurations",
classes="screen-subtitle",
),
Container(
Container(
Label("Config Files"),
DirectoryTree(
(
Path("/workspace/configs")
if Path("/workspace/configs").exists()
else Path.cwd()
),
id="config-tree",
),
classes="file-browser",
),
Container(
Container(
TextArea(
"",
language="yaml",
theme="monokai",
id="config-editor",
read_only=True,
),
classes="config-form",
),
Container(
Button("New", id="new-config", variant="primary"),
Button("Open", id="open-config", variant="primary"),
Button("Save", id="save-config", variant="success"),
Button("Validate", id="validate-config", variant="warning"),
Button("Edit Mode", id="toggle-edit", variant="default"),
Button("Load Example", id="load-example", variant="default"),
classes="config-actions",
),
Container(
Log(id="validation-log"),
classes="validation-log",
),
classes="config-editor",
),
classes="config-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
tree = self.query_one("#config-tree", DirectoryTree)
tree.show_root = False
tree.guide_depth = 3
log = self.query_one("#validation-log", Log)
log.write_line("Ready to load configuration files...")
@on(DirectoryTree.FileSelected)
def handle_file_selected(self, event: DirectoryTree.FileSelected) -> None:
"""Handle file selection from the directory tree."""
if event.path.suffix in [".yaml", ".yml"]:
self.load_config_file(event.path)
def load_config_file(self, path: Path) -> None:
"""Load a configuration file."""
self.current_config_path = path
try:
with open(path, "r") as f:
content = f.read()
self.config_data = yaml.safe_load(content)
editor = self.query_one("#config-editor", TextArea)
editor.load_text(content)
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line(f"✅ Loaded: {path.name}")
except Exception as e:
log = self.query_one("#validation-log", Log)
log.write_line(f"❌ Error loading {path.name}: {str(e)}")
@on(Button.Pressed, "#new-config")
def handle_new_config(self) -> None:
"""Create a new configuration."""
template = """# Axolotl Configuration
base_model:
model_type:
tokenizer_type:
# Dataset Configuration
datasets:
- path:
type:
# Training Configuration
output_dir: ./outputs
num_epochs: 3
micro_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 0.00002
warmup_steps: 100
eval_steps: 100
save_steps: 500
# LoRA Configuration (optional)
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
# Training optimizations
gradient_checkpointing: true
flash_attention: true
bf16: auto
tf32: true
# Logging
logging_steps: 10
wandb_project:
wandb_entity:
"""
editor = self.query_one("#config-editor", TextArea)
editor.load_text(template)
editor.read_only = False
self.edit_mode = True
self.update_edit_button()
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line("📝 New configuration created. Edit and save when ready.")
@on(Button.Pressed, "#save-config")
def handle_save_config(self) -> None:
"""Save the current configuration."""
editor = self.query_one("#config-editor", TextArea)
content = editor.text
if not content.strip():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ Cannot save empty configuration")
return
if not self.current_config_path:
default_path = Path("/workspace/configs/new_config.yaml")
default_path.parent.mkdir(parents=True, exist_ok=True)
self.current_config_path = default_path
try:
with open(self.current_config_path, "w") as f:
f.write(content)
log = self.query_one("#validation-log", Log)
log.write_line(f"💾 Saved: {self.current_config_path.name}")
except Exception as e:
log = self.query_one("#validation-log", Log)
log.write_line(f"❌ Error saving: {str(e)}")
@on(Button.Pressed, "#validate-config")
@work(thread=True)
async def handle_validate_config(self) -> None:
"""Validate the current configuration."""
editor = self.query_one("#config-editor", TextArea)
content = editor.text
if not content.strip():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ No configuration to validate")
return
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line("🔍 Validating configuration...")
try:
import tempfile
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
f.write(content)
temp_path = f.name
from argparse import Namespace
from axolotl.cli.config import check_user_config
args = Namespace(
config=temp_path,
debug=False,
debug_text_only=False,
debug_num_examples=5,
accelerate_config=None,
multi_gpu=False,
)
check_user_config(args)
log.write_line("✅ Configuration is valid!")
os.unlink(temp_path)
except Exception as e:
log.write_line(f"❌ Validation failed: {str(e)}")
if "temp_path" in locals():
os.unlink(temp_path)
@on(Button.Pressed, "#toggle-edit")
def handle_toggle_edit(self) -> None:
"""Toggle edit mode for the configuration."""
editor = self.query_one("#config-editor", TextArea)
self.edit_mode = not self.edit_mode
editor.read_only = not self.edit_mode
self.update_edit_button()
log = self.query_one("#validation-log", Log)
if self.edit_mode:
log.write_line("✏️ Edit mode enabled")
else:
log.write_line("👁️ View mode enabled")
@on(Button.Pressed, "#load-example")
async def handle_load_example(self) -> None:
"""Load an example configuration."""
examples_dir = Path("/workspace/axolotl/examples")
if not examples_dir.exists():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ Examples directory not found")
return
yaml_files = list(examples_dir.glob("**/*.yml")) + list(
examples_dir.glob("**/*.yaml")
)
if yaml_files:
self.load_config_file(yaml_files[0])
log = self.query_one("#validation-log", Log)
log.write_line(f"📚 Loaded example: {yaml_files[0].name}")
def update_edit_button(self) -> None:
"""Update the edit button appearance."""
button = self.query_one("#toggle-edit", Button)
if self.edit_mode:
button.variant = "warning"
button.label = "Edit Mode: ON"
else:
button.variant = "default"
button.label = "Edit Mode: OFF"
def action_new_config(self) -> None:
"""Create a new configuration."""
self.handle_new_config()
def action_save_config(self) -> None:
"""Save the current configuration."""
self.handle_save_config()
def action_validate_config(self) -> None:
"""Validate the current configuration."""
self.handle_validate_config()
def action_edit_mode(self) -> None:
"""Toggle edit mode."""
self.handle_toggle_edit()

View File

@@ -0,0 +1,440 @@
"""Dataset management screen for Axolotl TUI."""
import json
from pathlib import Path
from typing import Dict, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class DatasetScreen(BaseScreen):
"""Dataset management screen."""
BINDINGS = [
Binding("ctrl+p", "preprocess", "Preprocess"),
Binding("ctrl+v", "preview", "Preview"),
Binding("ctrl+i", "info", "Info"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.dataset-container {
layout: horizontal;
height: 100%;
}
.dataset-list {
width: 40%;
border: solid $primary;
padding: 1;
margin: 1;
}
.dataset-details {
width: 60%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.dataset-actions {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.dataset-actions Button {
margin: 0 1;
}
DataTable {
height: 100%;
}
.preview-container {
height: 100%;
border: solid $primary;
padding: 1;
}
TextArea {
height: 100%;
}
.stats-container {
layout: vertical;
padding: 1;
}
.stat-row {
layout: horizontal;
padding: 0 0 1 0;
}
.stat-label {
width: 50%;
color: $text-muted;
}
.stat-value {
width: 50%;
text-align: right;
text-style: bold;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
.progress-container {
padding: 1;
border: solid $warning;
margin: 1;
}
"""
def __init__(self):
"""Initialize the dataset screen."""
super().__init__(
title="Dataset Management",
subtitle="Browse, preview, and preprocess datasets",
)
self.datasets: Dict[str, Dict] = {}
self.selected_dataset: Optional[str] = None
self.preprocessing_active = False
def compose(self) -> ComposeResult:
"""Compose the dataset screen layout."""
yield Header()
yield Container(
Static("🦾 Dataset Management", classes="screen-title"),
Static(
"Browse, preview, and preprocess datasets", classes="screen-subtitle"
),
Container(
Container(
Label("Available Datasets"),
DataTable(id="dataset-table"),
Container(
Button("Load Dataset", id="load-dataset", variant="primary"),
Button("Preprocess", id="preprocess", variant="success"),
Button("Download", id="download", variant="default"),
Button("Refresh", id="refresh", variant="default"),
classes="dataset-actions",
),
classes="dataset-list",
),
Container(
TextArea("", id="dataset-preview", read_only=True),
Container(
Static("Dataset Name:", classes="stat-label"),
Static("-", id="stat-name", classes="stat-value"),
Static("Type:", classes="stat-label"),
Static("-", id="stat-type", classes="stat-value"),
Static("Size:", classes="stat-label"),
Static("-", id="stat-size", classes="stat-value"),
Static("Samples:", classes="stat-label"),
Static("-", id="stat-samples", classes="stat-value"),
Static("Features:", classes="stat-label"),
Static("-", id="stat-features", classes="stat-value"),
Static("Format:", classes="stat-label"),
Static("-", id="stat-format", classes="stat-value"),
Static("Preprocessed:", classes="stat-label"),
Static("-", id="stat-preprocessed", classes="stat-value"),
),
Log(id="processing-log"),
ProgressBar(total=100, id="preprocessing-progress"),
classes="dataset-details",
),
classes="dataset-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_dataset_table()
self.load_datasets()
log = self.query_one("#processing-log", Log)
log.write_line("Dataset manager ready.")
def setup_dataset_table(self) -> None:
"""Setup the dataset table."""
table = self.query_one("#dataset-table", DataTable)
table.add_columns("Name", "Type", "Size", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
@work(thread=True)
async def load_datasets(self) -> None:
"""Load available datasets."""
# Check for local datasets
datasets_dir = Path("/workspace/datasets")
if datasets_dir.exists():
for dataset_path in datasets_dir.glob("*"):
if dataset_path.is_dir():
self.datasets[dataset_path.name] = {
"name": dataset_path.name,
"path": str(dataset_path),
"type": "local",
"size": self.get_dir_size(dataset_path),
"status": "available",
}
# Check for HuggingFace datasets in configs
configs_dir = Path("/workspace/configs")
if configs_dir.exists():
for config_file in configs_dir.glob("*.yaml"):
try:
import yaml
with open(config_file) as f:
config = yaml.safe_load(f)
if "datasets" in config:
for ds in config.get("datasets", []):
if "path" in ds:
ds_name = ds["path"].split("/")[-1]
self.datasets[ds_name] = {
"name": ds_name,
"path": ds["path"],
"type": ds.get("type", "huggingface"),
"size": "Unknown",
"status": "remote",
}
except Exception:
pass
self.refresh_dataset_table()
def get_dir_size(self, path: Path) -> str:
"""Get human-readable directory size."""
total_size = sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
for unit in ["B", "KB", "MB", "GB"]:
if total_size < 1024.0:
return f"{total_size:.2f} {unit}"
total_size /= 1024.0
return f"{total_size:.2f} TB"
def refresh_dataset_table(self) -> None:
"""Refresh the dataset table."""
table = self.query_one("#dataset-table", DataTable)
table.clear()
for name, info in self.datasets.items():
table.add_row(
name[:30],
info["type"],
info["size"],
info["status"],
)
@on(DataTable.RowSelected)
def handle_dataset_selected(self, event: DataTable.RowSelected) -> None:
"""Handle dataset selection from table."""
if event.cursor_row >= 0:
dataset_names = list(self.datasets.keys())
if event.cursor_row < len(dataset_names):
self.selected_dataset = dataset_names[event.cursor_row]
self.load_dataset_preview()
self.update_dataset_stats()
@work(thread=True)
async def load_dataset_preview(self) -> None:
"""Load preview of selected dataset."""
if not self.selected_dataset:
return
dataset_info = self.datasets[self.selected_dataset]
preview_text = ""
try:
if dataset_info["type"] == "local" and Path(dataset_info["path"]).exists():
# Load first few samples from local dataset
sample_files = list(Path(dataset_info["path"]).glob("*.json"))[:3]
samples = []
for sample_file in sample_files:
with open(sample_file) as f:
samples.append(json.load(f))
preview_text = json.dumps(samples, indent=2)
else:
# Show dataset info for remote datasets
preview_text = json.dumps(dataset_info, indent=2)
except Exception as e:
preview_text = f"Error loading preview: {str(e)}"
preview = self.query_one("#dataset-preview", TextArea)
preview.load_text(preview_text)
def update_dataset_stats(self) -> None:
"""Update dataset statistics display."""
if not self.selected_dataset:
return
info = self.datasets[self.selected_dataset]
self.query_one("#stat-name", Static).update(info["name"])
self.query_one("#stat-type", Static).update(info["type"])
self.query_one("#stat-size", Static).update(info["size"])
self.query_one("#stat-samples", Static).update("N/A")
self.query_one("#stat-features", Static).update("N/A")
self.query_one("#stat-format", Static).update("JSON")
self.query_one("#stat-preprocessed", Static).update("No")
@on(Button.Pressed, "#preprocess")
@work(thread=True)
async def handle_preprocess(self) -> None:
"""Preprocess selected dataset."""
if not self.selected_dataset or self.preprocessing_active:
return
self.preprocessing_active = True
dataset_info = self.datasets[self.selected_dataset]
log = self.query_one("#processing-log", Log)
log.clear()
log.write_line(f"🔄 Starting preprocessing for {self.selected_dataset}...")
progress = self.query_one("#preprocessing-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
import tempfile
# Create a temporary config for preprocessing
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
config = {
"datasets": [
{
"path": dataset_info["path"],
"type": dataset_info.get("type", "alpaca"),
}
],
"output_dir": f"/tmp/preprocessed_{self.selected_dataset}",
}
import yaml
yaml.dump(config, f)
temp_config = f.name
# Run preprocessing
cmd = ["python", "-m", "axolotl.cli.preprocess", temp_config]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
# Monitor progress
for line in process.stdout:
log.write_line(line.strip())
# Update progress bar based on output
if "Processing" in line:
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ Preprocessing completed successfully!")
dataset_info["status"] = "preprocessed"
progress.update(progress=100)
else:
log.write_line(
f"❌ Preprocessing failed with code {process.returncode}"
)
import os
os.unlink(temp_config)
except Exception as e:
log.write_line(f"❌ Error during preprocessing: {str(e)}")
finally:
self.preprocessing_active = False
self.refresh_dataset_table()
@on(Button.Pressed, "#load-dataset")
async def handle_load_dataset(self) -> None:
"""Load a new dataset."""
log = self.query_one("#processing-log", Log)
log.write_line("📦 Load dataset functionality coming soon...")
@on(Button.Pressed, "#download")
@work(thread=True)
async def handle_download(self) -> None:
"""Download a remote dataset."""
if not self.selected_dataset:
return
dataset_info = self.datasets[self.selected_dataset]
if dataset_info["type"] != "huggingface":
return
log = self.query_one("#processing-log", Log)
log.clear()
log.write_line(f"📥 Downloading {self.selected_dataset} from HuggingFace...")
try:
from datasets import load_dataset
dataset = load_dataset(dataset_info["path"])
save_path = Path(f"/workspace/datasets/{self.selected_dataset}")
save_path.mkdir(parents=True, exist_ok=True)
dataset.save_to_disk(str(save_path))
log.write_line(f"✅ Downloaded to {save_path}")
dataset_info["type"] = "local"
dataset_info["status"] = "available"
dataset_info["path"] = str(save_path)
self.refresh_dataset_table()
except Exception as e:
log.write_line(f"❌ Download failed: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh dataset list."""
self.load_datasets()
def action_preprocess(self) -> None:
"""Preprocess selected dataset."""
self.handle_preprocess()
def action_refresh(self) -> None:
"""Refresh dataset list."""
self.handle_refresh()

View File

@@ -0,0 +1,445 @@
"""Inference and testing screen for Axolotl TUI."""
from pathlib import Path
from typing import Dict, List, Optional
from textual import events, on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
Input,
Label,
Log,
Select,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class InferenceScreen(BaseScreen):
"""Inference and testing screen."""
BINDINGS = [
Binding("ctrl+enter", "send_message", "Send"),
Binding("ctrl+c", "clear_chat", "Clear"),
Binding("ctrl+l", "load_model", "Load Model"),
Binding("ctrl+s", "save_chat", "Save Chat"),
]
CSS = """
.inference-container {
layout: horizontal;
height: 100%;
}
.model-selector {
width: 30%;
border: solid $primary;
padding: 1;
margin: 1;
}
.chat-interface {
width: 70%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.chat-history {
height: 70%;
border: solid $primary;
padding: 1;
margin: 0 0 1 0;
}
.input-area {
height: 20%;
border: solid $warning;
padding: 1;
margin: 0 0 1 0;
}
.chat-controls {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.chat-controls Button {
margin: 0 1;
}
.model-info {
padding: 1;
border: solid $surface;
margin: 1 0;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
TextArea {
height: 100%;
}
Log {
height: 100%;
}
"""
def __init__(self):
"""Initialize the inference screen."""
super().__init__(
title="Inference & Testing", subtitle="Interactive chat and model testing"
)
self.loaded_model: Optional[str] = None
self.chat_history: List[Dict[str, str]] = []
def compose(self) -> ComposeResult:
"""Compose the inference screen layout."""
yield Container(
Static("🦾 Inference & Testing", classes="screen-title"),
Static("Interactive chat and model testing", classes="screen-subtitle"),
Container(
Container(
Label("Model Selection"),
Select(
[("No model loaded", "none")],
id="model-select",
value="none",
),
Container(
Button("Load Model", id="load-model", variant="primary"),
Button("Unload", id="unload-model", variant="default"),
Button("Gradio UI", id="gradio-ui", variant="success"),
),
Container(
Static("No model loaded", id="model-status"),
classes="model-info",
),
Label("Inference Parameters"),
Container(
Label("Temperature:"),
Input(value="0.7", id="temperature"),
Label("Max Tokens:"),
Input(value="256", id="max-tokens"),
Label("Top P:"),
Input(value="0.9", id="top-p"),
),
classes="model-selector",
),
Container(
Container(
Log(id="chat-history"),
classes="chat-history",
),
Container(
TextArea(
id="message-input",
),
classes="input-area",
),
Container(
Button("Send [Ctrl+Enter]", id="send", variant="primary"),
Button("Clear Chat", id="clear", variant="warning"),
Button("Save Chat", id="save-chat", variant="default"),
Button("Load Examples", id="load-examples", variant="default"),
classes="chat-controls",
),
classes="chat-interface",
),
classes="inference-container",
),
id="content",
)
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.load_available_models()
chat = self.query_one("#chat-history", Log)
chat.write_line("💬 Welcome to Axolotl Inference!")
chat.write_line("Load a model to start chatting.")
@work(thread=True)
async def load_available_models(self) -> None:
"""Load list of available models."""
models = [("No model loaded", "none")]
chat = self.query_one("#chat-history", Log)
chat.write_line("🔍 Scanning for available models...")
# Check for trained models
outputs_dir = Path("./outputs")
chat.write_line(f"Checking outputs directory: {outputs_dir.absolute()}")
if outputs_dir.exists():
found_models = 0
for model_dir in outputs_dir.glob("*"):
if model_dir.is_dir():
# Look for various model file types
model_files = (
list(model_dir.glob("pytorch_model.bin"))
+ list(model_dir.glob("model.safetensors"))
+ list(model_dir.glob("*.bin"))
+ list(model_dir.glob("*.safetensors"))
)
if model_files:
models.append((model_dir.name, str(model_dir)))
found_models += 1
chat.write_line(f"Found {found_models} trained models in outputs/")
else:
chat.write_line("outputs/ directory not found")
# Add some example/demo models for testing
models.extend(
[
("Demo: GPT-2 Small", "gpt2"),
("Demo: TinyLlama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
("Demo: Phi-2", "microsoft/phi-2"),
]
)
select = self.query_one("#model-select", Select)
select.set_options(models)
chat.write_line(f"✅ Loaded {len(models)} models in dropdown")
@on(Button.Pressed, "#load-model")
@work(thread=True)
async def handle_load_model(self) -> None:
"""Load selected model for inference."""
select = self.query_one("#model-select", Select)
if select.value == "none":
return
chat = self.query_one("#chat-history", Log)
chat.write_line(f"🔄 Loading model: {select.value}")
status = self.query_one("#model-status", Static)
status.update("Loading...")
try:
# Simulate model loading (in real implementation, would load the actual model)
import time
time.sleep(2) # Simulate loading time
self.loaded_model = select.value
status.update(f"✅ Loaded: {Path(select.value).name}")
chat.write_line("✅ Model loaded successfully!")
chat.write_line("You can now start chatting.")
except Exception as e:
status.update("❌ Failed to load")
chat.write_line(f"❌ Failed to load model: {str(e)}")
@on(Button.Pressed, "#send")
async def handle_send_message(self) -> None:
"""Send message to model."""
if not self.loaded_model:
chat = self.query_one("#chat-history", Log)
chat.write_line("⚠️ Please load a model first")
return
message_input = self.query_one("#message-input", TextArea)
message = message_input.text.strip()
if not message:
return
# Add user message to chat
chat = self.query_one("#chat-history", Log)
chat.write_line(f"👤 User: {message}")
# Clear input
message_input.clear()
# Add to history
self.chat_history.append({"role": "user", "content": message})
# Generate response (placeholder)
self.generate_response(message)
@on(TextArea.Changed, "#message-input")
def on_message_input_changed(self, event: TextArea.Changed) -> None:
"""Handle changes to the message input."""
# This could be used for features like typing indicators
pass
def on_key(self, event: events.Key) -> None:
"""Handle key events globally."""
# Check if we're focused on the message input and Ctrl+Enter is pressed
focused = self.focused
if focused and focused.id == "message-input" and event.key == "ctrl+enter":
event.prevent_default()
self.handle_send_message()
@work(thread=True)
async def generate_response(self, message: str) -> None:
"""Generate model response."""
chat = self.query_one("#chat-history", Log)
chat.write_line("🤖 Assistant: Thinking...")
try:
# Get inference parameters
float(self.query_one("#temperature", Input).value)
int(self.query_one("#max-tokens", Input).value)
float(self.query_one("#top-p", Input).value)
if not self.loaded_model or self.loaded_model == "none":
response = "I don't have a model loaded yet. Please load a model first using the 'Load Model' button."
elif self.loaded_model.startswith("gpt2"):
# Simple response for GPT-2
responses = [
f"Thanks for your message: '{message}'. I'm a GPT-2 model running in demo mode.",
"I understand you're testing the interface. GPT-2 models are great for experimentation!",
"This is a simulated GPT-2 response. In a real setup, I'd generate text based on your input.",
f"GPT-2 here! You said: '{message}'. I'd normally continue this conversation creatively.",
]
import random
response = random.choice(responses)
elif "llama" in self.loaded_model.lower():
# Response for Llama models
response = f"🦙 LLaMA model here! You asked: '{message}'. I'm designed for helpful, harmless, and honest conversations. How can I assist you today?"
elif "phi" in self.loaded_model.lower():
# Response for Phi models
response = f"Phi model responding! Your message: '{message}'. I'm optimized for reasoning and code tasks. What would you like to explore?"
else:
# Generic response for other models
response = f"Model '{self.loaded_model}' responding to: '{message}'. I'm ready to help with your questions!"
# Simulate inference time
import time
time.sleep(0.5)
# Clear the "thinking" message and show response
chat.write_line(f"🤖 Assistant: {response}")
# Add to history
self.chat_history.append({"role": "assistant", "content": response})
except Exception as e:
chat.write_line(f"❌ Error generating response: {str(e)}")
@on(Button.Pressed, "#clear")
def handle_clear_chat(self) -> None:
"""Clear chat history."""
chat = self.query_one("#chat-history", Log)
chat.clear()
self.chat_history = []
chat.write_line("💬 Chat cleared. Start a new conversation!")
@on(Button.Pressed, "#save-chat")
def handle_save_chat(self) -> None:
"""Save chat history to file."""
if not self.chat_history:
chat = self.query_one("#chat-history", Log)
chat.write_line("⚠️ No chat history to save")
return
try:
import json
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"chat_history_{timestamp}.json"
with open(filename, "w") as f:
json.dump(self.chat_history, f, indent=2)
chat = self.query_one("#chat-history", Log)
chat.write_line(f"💾 Chat saved to {filename}")
except Exception as e:
chat = self.query_one("#chat-history", Log)
chat.write_line(f"❌ Error saving chat: {str(e)}")
@on(Button.Pressed, "#load-examples")
def handle_load_examples(self) -> None:
"""Load example prompts."""
examples = [
"Explain the concept of machine learning in simple terms.",
"Write a Python function to calculate fibonacci numbers.",
"What are the benefits of fine-tuning language models?",
"Describe the difference between supervised and unsupervised learning.",
]
chat = self.query_one("#chat-history", Log)
chat.write_line("📚 Example prompts:")
for i, example in enumerate(examples, 1):
chat.write_line(f"{i}. {example}")
chat.write_line("Copy and paste any example to try it out!")
@on(Button.Pressed, "#gradio-ui")
@work(thread=True)
async def handle_gradio_ui(self) -> None:
"""Launch Gradio web interface."""
chat = self.query_one("#chat-history", Log)
chat.write_line("🌐 Launching Gradio web interface...")
try:
import subprocess
if self.loaded_model:
cmd = [
"python",
"-m",
"axolotl.cli.inference",
self.loaded_model,
"--gradio",
]
else:
chat.write_line("⚠️ No model loaded. Loading default interface...")
cmd = ["python", "-m", "axolotl.cli.inference", "--gradio"]
subprocess.Popen(cmd)
chat.write_line("✅ Gradio interface launched! Check your browser.")
except Exception as e:
chat.write_line(f"❌ Error launching Gradio: {str(e)}")
@on(Button.Pressed, "#unload-model")
def handle_unload_model(self) -> None:
"""Unload current model."""
self.loaded_model = None
status = self.query_one("#model-status", Static)
status.update("No model loaded")
select = self.query_one("#model-select", Select)
select.value = "none"
chat = self.query_one("#chat-history", Log)
chat.write_line("🔄 Model unloaded")
def action_send_message(self) -> None:
"""Send message action."""
self.handle_send_message()
def action_clear_chat(self) -> None:
"""Clear chat action."""
self.handle_clear_chat()
def action_load_model(self) -> None:
"""Load model action."""
self.handle_load_model()
def action_save_chat(self) -> None:
"""Save chat action."""
self.handle_save_chat()

View File

@@ -0,0 +1,373 @@
"""Model management screen for Axolotl TUI."""
from pathlib import Path
from typing import Dict, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container, ScrollableContainer
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Static,
TabbedContent,
TabPane,
)
from axolotl.tui.screens.base import BaseScreen
class ModelScreen(BaseScreen):
"""Model management screen."""
BINDINGS = [
Binding("ctrl+m", "merge_lora", "Merge LoRA"),
Binding("ctrl+q", "quantize", "Quantize"),
Binding("ctrl+e", "evaluate", "Evaluate"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.model-container {
layout: horizontal;
height: 100%;
}
.model-list {
width: 50%;
border: solid $primary;
padding: 1;
margin: 1;
}
.model-operations {
width: 50%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.model-actions {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.model-actions Button {
margin: 0 1;
}
DataTable {
height: 80%;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
"""
def __init__(self):
"""Initialize the model screen."""
super().__init__(
title="Model Management",
subtitle="Manage trained models, merge LoRA adapters, and quantize models",
)
self.models: Dict[str, Dict] = {}
self.selected_model: Optional[str] = None
def compose(self) -> ComposeResult:
"""Compose the model screen layout."""
yield Header()
with Container(id="content"):
yield Static("🦾 Model Management", classes="screen-title")
yield Static(
"Manage trained models, merge LoRA adapters, and quantize models",
classes="screen-subtitle",
)
with Container(classes="model-container"):
with Container(classes="model-list"):
yield Label("Available Models")
yield DataTable(id="model-table")
with Container(classes="model-actions"):
yield Button("Merge LoRA", id="merge-lora", variant="primary")
yield Button("Quantize", id="quantize", variant="success")
yield Button("Evaluate", id="evaluate", variant="warning")
yield Button("Refresh", id="refresh", variant="default")
with Container(classes="model-operations"):
with TabbedContent():
with TabPane("Operations"):
with Container():
yield Log(id="operations-log")
with Container():
yield Label("Operation Progress:")
yield ProgressBar(
total=100,
id="operation-progress",
)
with TabPane("Model Info"):
with ScrollableContainer():
yield Static(
"Model information will appear here",
id="model-info",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_model_table()
self.load_models()
log = self.query_one("#operations-log", Log)
log.write_line("Model manager ready.")
def setup_model_table(self) -> None:
"""Setup the model table."""
table = self.query_one("#model-table", DataTable)
table.add_columns("Name", "Type", "Size", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
@work(thread=True)
async def load_models(self) -> None:
"""Load available models."""
# Check outputs directory for trained models
outputs_dir = Path("./outputs")
if outputs_dir.exists():
for model_dir in outputs_dir.glob("*"):
if model_dir.is_dir():
self.models[model_dir.name] = {
"name": model_dir.name,
"path": str(model_dir),
"type": "checkpoint",
"size": self.get_dir_size(model_dir),
"status": "available",
}
self.refresh_model_table()
def get_dir_size(self, path: Path) -> str:
"""Get human-readable directory size."""
try:
total_size = sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
for unit in ["B", "KB", "MB", "GB"]:
if total_size < 1024.0:
return f"{total_size:.2f} {unit}"
total_size /= 1024.0
return f"{total_size:.2f} TB"
except Exception:
return "Unknown"
def refresh_model_table(self) -> None:
"""Refresh the model table."""
table = self.query_one("#model-table", DataTable)
table.clear()
for name, info in self.models.items():
table.add_row(
name[:30],
info["type"],
info["size"],
info["status"],
)
@on(DataTable.RowSelected)
def handle_model_selected(self, event: DataTable.RowSelected) -> None:
"""Handle model selection from table."""
if event.cursor_row >= 0:
model_names = list(self.models.keys())
if event.cursor_row < len(model_names):
self.selected_model = model_names[event.cursor_row]
self.update_model_info()
def update_model_info(self) -> None:
"""Update model information display."""
if not self.selected_model:
return
info = self.models[self.selected_model]
info_text = f"""
Model Name: {info['name']}
Path: {info['path']}
Type: {info['type']}
Size: {info['size']}
Status: {info['status']}
"""
self.query_one("#model-info", Static).update(info_text)
@on(Button.Pressed, "#merge-lora")
@work(thread=True)
async def handle_merge_lora(self) -> None:
"""Merge LoRA adapters with base model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Merging LoRA adapters for {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = ["python", "-m", "axolotl.cli.merge_lora", model_info["path"]]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ LoRA merge completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ LoRA merge failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during LoRA merge: {str(e)}")
@on(Button.Pressed, "#quantize")
@work(thread=True)
async def handle_quantize(self) -> None:
"""Quantize selected model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Quantizing {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = [
"python",
"-m",
"axolotl.cli.quantize",
model_info["path"],
"--output-dir",
f"{model_info['path']}_quantized",
]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(5)
process.wait()
if process.returncode == 0:
log.write_line("✅ Quantization completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ Quantization failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during quantization: {str(e)}")
@on(Button.Pressed, "#evaluate")
@work(thread=True)
async def handle_evaluate(self) -> None:
"""Evaluate selected model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Evaluating {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = ["python", "-m", "axolotl.cli.evaluate", model_info["path"]]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ Evaluation completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ Evaluation failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during evaluation: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh model list."""
self.load_models()
def action_merge_lora(self) -> None:
"""Merge LoRA adapters."""
self.handle_merge_lora()
def action_quantize(self) -> None:
"""Quantize model."""
self.handle_quantize()
def action_evaluate(self) -> None:
"""Evaluate model."""
self.handle_evaluate()
def action_refresh(self) -> None:
"""Refresh model list."""
self.handle_refresh()

View File

@@ -0,0 +1,414 @@
"""System monitoring screen for Axolotl TUI."""
import psutil
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Sparkline,
Static,
)
from axolotl.tui.screens.base import BaseScreen
class MonitorScreen(BaseScreen):
"""System monitoring screen."""
BINDINGS = [
Binding("r", "refresh", "Refresh"),
Binding("ctrl+k", "kill_process", "Kill Process"),
]
CSS = """
.monitor-container {
layout: vertical;
height: 100%;
}
.metrics-grid {
layout: horizontal;
height: 20%;
padding: 1;
}
.metric-card {
width: 25%;
border: solid $surface;
padding: 1;
margin: 0 1;
}
.metric-label {
text-style: bold;
color: $text-muted;
text-align: center;
}
.metric-value {
text-style: bold;
text-align: center;
padding: 1;
}
.charts-container {
height: 40%;
layout: horizontal;
padding: 1;
}
.chart-panel {
width: 50%;
border: solid $primary;
padding: 1;
margin: 0 1;
}
.processes-container {
height: 40%;
border: solid $warning;
padding: 1;
margin: 1;
}
DataTable {
height: 90%;
}
.process-controls {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.process-controls Button {
margin: 0 1;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
Sparkline {
height: 8;
}
ProgressBar {
margin: 1 0;
}
"""
def __init__(self):
"""Initialize the monitor screen."""
super().__init__(
title="System Monitor",
subtitle="Monitor system resources and running processes",
)
self.cpu_history = []
self.memory_history = []
self.gpu_history = []
def compose(self) -> ComposeResult:
"""Compose the monitor screen layout."""
yield Header()
yield Container(
Static("🦾 System Monitor", classes="screen-title"),
Static(
"Monitor system resources and running processes",
classes="screen-subtitle",
),
Container(
Container(
Container(
Static("CPU Usage", classes="metric-label"),
Static("0%", id="cpu-usage", classes="metric-value"),
ProgressBar(total=100, id="cpu-progress"),
classes="metric-card",
),
Container(
Static("Memory", classes="metric-label"),
Static("0%", id="memory-usage", classes="metric-value"),
ProgressBar(total=100, id="memory-progress"),
classes="metric-card",
),
Container(
Static("GPU Usage", classes="metric-label"),
Static("0%", id="gpu-usage", classes="metric-value"),
ProgressBar(total=100, id="gpu-progress"),
classes="metric-card",
),
Container(
Static("Temperature", classes="metric-label"),
Static("0°C", id="temperature", classes="metric-value"),
classes="metric-card",
),
classes="metrics-grid",
),
Container(
Container(
Label("CPU History"),
Sparkline([], id="cpu-sparkline"),
classes="chart-panel",
),
Container(
Label("Memory History"),
Sparkline([], id="memory-sparkline"),
classes="chart-panel",
),
classes="charts-container",
),
Container(
DataTable(id="process-table"),
Log(id="gpu-info"),
Log(id="system-logs"),
classes="processes-container",
),
classes="monitor-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_process_table()
self.start_monitoring()
# Initial system info
self.update_system_info()
self.update_gpu_info()
def setup_process_table(self) -> None:
"""Setup the process table."""
table = self.query_one("#process-table", DataTable)
table.add_columns("PID", "Name", "CPU%", "Memory%", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
def start_monitoring(self) -> None:
"""Start the monitoring timer."""
self.set_interval(2.0, self.update_system_metrics)
@work(thread=True)
async def update_system_metrics(self) -> None:
"""Update system metrics."""
try:
# CPU usage
cpu_percent = psutil.cpu_percent(interval=None)
self.cpu_history.append(cpu_percent)
if len(self.cpu_history) > 50:
self.cpu_history.pop(0)
# Memory usage
memory = psutil.virtual_memory()
memory_percent = memory.percent
self.memory_history.append(memory_percent)
if len(self.memory_history) > 50:
self.memory_history.pop(0)
# GPU usage (if available)
gpu_percent = self.get_gpu_usage()
self.gpu_history.append(gpu_percent)
if len(self.gpu_history) > 50:
self.gpu_history.pop(0)
# Temperature
temperature = self.get_temperature()
# Update UI
self.update_metrics_display(
cpu_percent, memory_percent, gpu_percent, temperature
)
self.update_sparklines()
self.update_process_table()
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error updating metrics: {str(e)}")
def get_gpu_usage(self) -> float:
"""Get GPU usage percentage."""
try:
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
return util.gpu
except Exception:
return 0.0
def get_temperature(self) -> str:
"""Get system temperature."""
try:
temps = psutil.sensors_temperatures()
if temps:
for name, entries in temps.items():
if entries:
return f"{entries[0].current:.1f}°C"
return "N/A"
except Exception:
return "N/A"
def update_metrics_display(
self, cpu: float, memory: float, gpu: float, temp: str
) -> None:
"""Update metrics display."""
self.query_one("#cpu-usage", Static).update(f"{cpu:.1f}%")
self.query_one("#memory-usage", Static).update(f"{memory:.1f}%")
self.query_one("#gpu-usage", Static).update(f"{gpu:.1f}%")
self.query_one("#temperature", Static).update(temp)
self.query_one("#cpu-progress", ProgressBar).update(progress=cpu)
self.query_one("#memory-progress", ProgressBar).update(progress=memory)
self.query_one("#gpu-progress", ProgressBar).update(progress=gpu)
def update_sparklines(self) -> None:
"""Update sparkline charts."""
if self.cpu_history:
cpu_sparkline = self.query_one("#cpu-sparkline", Sparkline)
cpu_sparkline.data = self.cpu_history
if self.memory_history:
memory_sparkline = self.query_one("#memory-sparkline", Sparkline)
memory_sparkline.data = self.memory_history
def update_process_table(self) -> None:
"""Update the process table."""
table = self.query_one("#process-table", DataTable)
table.clear()
try:
# Get top processes by CPU usage
processes = []
for proc in psutil.process_iter(
["pid", "name", "cpu_percent", "memory_percent", "status"]
):
try:
pinfo = proc.info
if pinfo["cpu_percent"] > 0.1: # Only show processes using CPU
processes.append(pinfo)
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
# Sort by CPU usage
processes.sort(key=lambda x: x["cpu_percent"], reverse=True)
# Add top 20 processes
for proc in processes[:20]:
table.add_row(
str(proc["pid"]),
proc["name"][:20],
f"{proc['cpu_percent']:.1f}%",
f"{proc['memory_percent']:.1f}%",
proc["status"],
)
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error updating process table: {str(e)}")
def update_system_info(self) -> None:
"""Update system information."""
try:
# System info
psutil.boot_time()
cpu_count = psutil.cpu_count()
memory = psutil.virtual_memory()
log = self.query_one("#system-logs", Log)
log.write_line(f"System started. CPU cores: {cpu_count}")
log.write_line(f"Total memory: {memory.total / (1024**3):.1f} GB")
log.write_line(f"Available memory: {memory.available / (1024**3):.1f} GB")
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error getting system info: {str(e)}")
def update_gpu_info(self) -> None:
"""Update GPU information."""
try:
import pynvml
pynvml.nvmlInit()
device_count = pynvml.nvmlDeviceGetCount()
log = self.query_one("#gpu-info", Log)
log.clear()
log.write_line(f"Found {device_count} GPU(s)")
for i in range(device_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
name = pynvml.nvmlDeviceGetName(handle).decode()
memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
log.write_line(f"\nGPU {i}: {name}")
log.write_line(
f"Memory: {memory_info.used / (1024**3):.1f} / {memory_info.total / (1024**3):.1f} GB"
)
log.write_line(f"Free: {memory_info.free / (1024**3):.1f} GB")
except Exception as e:
log = self.query_one("#gpu-info", Log)
log.clear()
log.write_line(f"GPU info unavailable: {str(e)}")
@on(Button.Pressed, "#kill-process")
def handle_kill_process(self) -> None:
"""Kill selected process."""
table = self.query_one("#process-table", DataTable)
if table.cursor_row >= 0:
try:
row = table.get_row_at(table.cursor_row)
pid = int(row[0])
process = psutil.Process(pid)
process.terminate()
log = self.query_one("#system-logs", Log)
log.write_line(f"Terminated process {pid}")
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error killing process: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh all metrics."""
self.update_system_info()
self.update_gpu_info()
log = self.query_one("#system-logs", Log)
log.write_line("Metrics refreshed")
@on(Button.Pressed, "#auto-refresh")
def handle_auto_refresh(self) -> None:
"""Toggle auto refresh."""
log = self.query_one("#system-logs", Log)
log.write_line("Auto refresh is always enabled (every 2 seconds)")
def action_refresh(self) -> None:
"""Refresh action."""
self.handle_refresh()
def action_kill_process(self) -> None:
"""Kill process action."""
self.handle_kill_process()

View File

@@ -0,0 +1,545 @@
"""Training management screen for Axolotl TUI."""
import subprocess
import threading
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
Sparkline,
Static,
)
from axolotl.tui.screens.base import BaseScreen
@dataclass
class TrainingJob:
"""Represents a training job."""
id: str
config_path: str
status: str # pending, running, completed, failed
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
process: Optional[subprocess.Popen] = None
log_file: Optional[str] = None
current_epoch: int = 0
total_epochs: int = 0
current_loss: float = 0.0
losses: List[float] = None
def __post_init__(self):
if self.losses is None:
self.losses = []
class TrainingScreen(BaseScreen):
"""Training management screen."""
BINDINGS = [
Binding("ctrl+t", "new_training", "New Training"),
Binding("ctrl+r", "resume_training", "Resume"),
Binding("ctrl+x", "stop_training", "Stop"),
Binding("ctrl+l", "view_logs", "View Logs"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.training-container {
layout: vertical;
height: 100%;
}
.job-list-container {
height: 40%;
border: solid $primary;
padding: 1;
margin: 1;
}
.job-details-container {
height: 60%;
padding: 1;
}
.control-panel {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
border: solid $secondary;
margin: 1;
}
.control-panel Button {
margin: 0 1;
}
.metrics-panel {
layout: horizontal;
height: 10;
border: solid $primary;
padding: 1;
margin: 1;
}
.metric-card {
width: 25%;
border: tall $surface;
padding: 1;
margin: 0 1;
}
.metric-label {
text-style: bold;
color: $text-muted;
}
.metric-value {
text-style: bold;
text-align: center;
padding: 1;
}
.log-viewer {
border: solid $warning;
padding: 1;
margin: 1;
}
#training-logs {
height: 100%;
}
DataTable {
height: 100%;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
.sparkline-container {
height: 5;
border: solid $success;
padding: 1;
margin: 1;
}
"""
def __init__(self):
"""Initialize the training screen."""
super().__init__(
title="Training Management",
subtitle="Launch, monitor, and manage training jobs",
)
self.jobs: Dict[str, TrainingJob] = {}
self.selected_job_id: Optional[str] = None
self.update_timer = None
def compose(self) -> ComposeResult:
"""Compose the training screen layout."""
yield Header()
yield Container(
Static("🦾 Training Management", classes="screen-title"),
Static(
"Launch, monitor, and manage training jobs", classes="screen-subtitle"
),
Container(
Container(
Label("Active Training Jobs"),
DataTable(id="job-table"),
classes="job-list-container",
),
Container(
Button("New Training", id="new-training", variant="primary"),
Button("Resume", id="resume-training", variant="success"),
Button("Stop", id="stop-training", variant="error"),
Button("View Logs", id="view-logs", variant="default"),
Button("Clear Completed", id="clear-completed", variant="warning"),
Button("Refresh", id="refresh", variant="default"),
classes="control-panel",
),
Container(
Container(
Static("Current Epoch", classes="metric-label"),
Static("0 / 0", id="epoch-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Loss", classes="metric-label"),
Static("0.000", id="loss-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Status", classes="metric-label"),
Static("Idle", id="status-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Duration", classes="metric-label"),
Static(
"00:00:00", id="duration-metric", classes="metric-value"
),
classes="metric-card",
),
classes="metrics-panel",
),
Container(
Label("Loss History"),
Sparkline(
[],
id="loss-sparkline",
summary_function=min,
),
classes="sparkline-container",
),
Container(
Log(id="training-logs"),
classes="log-viewer",
),
classes="job-details-container",
),
classes="training-container",
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_job_table()
self.start_update_timer()
log = self.query_one("#training-logs", Log)
log.write_line(
"Training manager ready. Select a configuration to start training."
)
def setup_job_table(self) -> None:
"""Setup the job table."""
table = self.query_one("#job-table", DataTable)
table.add_columns("ID", "Config", "Status", "Epoch", "Loss", "Duration")
table.cursor_type = "row"
table.zebra_stripes = True
def start_update_timer(self) -> None:
"""Start the periodic update timer."""
self.set_interval(2.0, self.update_job_status)
@work(thread=True)
async def update_job_status(self) -> None:
"""Update job status periodically."""
for job_id, job in self.jobs.items():
if job.status == "running" and job.process:
poll = job.process.poll()
if poll is not None:
if poll == 0:
job.status = "completed"
else:
job.status = "failed"
job.end_time = datetime.now()
self.refresh_job_table()
self.update_selected_job_metrics()
def refresh_job_table(self) -> None:
"""Refresh the job table."""
table = self.query_one("#job-table", DataTable)
table.clear()
for job_id, job in self.jobs.items():
duration = self.calculate_duration(job)
table.add_row(
job_id[:8],
Path(job.config_path).name,
job.status,
f"{job.current_epoch}/{job.total_epochs}",
f"{job.current_loss:.4f}" if job.current_loss else "N/A",
duration,
)
def calculate_duration(self, job: TrainingJob) -> str:
"""Calculate job duration."""
if not job.start_time:
return "00:00:00"
end_time = job.end_time or datetime.now()
duration = end_time - job.start_time
hours = int(duration.total_seconds() // 3600)
minutes = int((duration.total_seconds() % 3600) // 60)
seconds = int(duration.total_seconds() % 60)
return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
def update_selected_job_metrics(self) -> None:
"""Update metrics for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
self.query_one("#epoch-metric", Static).update(
f"{job.current_epoch} / {job.total_epochs}"
)
self.query_one("#loss-metric", Static).update(
f"{job.current_loss:.4f}" if job.current_loss else "N/A"
)
self.query_one("#status-metric", Static).update(job.status.upper())
self.query_one("#duration-metric", Static).update(self.calculate_duration(job))
if job.losses:
sparkline = self.query_one("#loss-sparkline", Sparkline)
sparkline.data = job.losses[-50:] # Show last 50 loss values
@on(DataTable.RowSelected)
def handle_row_selected(self, event: DataTable.RowSelected) -> None:
"""Handle job selection from table."""
if event.cursor_row >= 0:
job_ids = list(self.jobs.keys())
if event.cursor_row < len(job_ids):
self.selected_job_id = job_ids[event.cursor_row]
self.update_selected_job_metrics()
self.load_job_logs()
def load_job_logs(self) -> None:
"""Load logs for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
if job.log_file and Path(job.log_file).exists():
try:
with open(job.log_file, "r") as f:
content = f.read()
log = self.query_one("#training-logs", Log)
log.clear()
for line in content.split("\n")[-100:]: # Show last 100 lines
if line.strip():
log.write_line(line)
except Exception as e:
log = self.query_one("#training-logs", Log)
log.write_line(f"Error loading logs: {str(e)}")
@on(Button.Pressed, "#new-training")
async def handle_new_training(self) -> None:
"""Start a new training job."""
from axolotl.tui.dialogs.training import NewTrainingDialog
dialog = NewTrainingDialog()
result = await self.app.push_screen_wait(dialog)
if result and "config_path" in result:
await self.start_training_job(
result["config_path"], result.get("launcher", "accelerate")
)
@work(thread=True)
async def start_training_job(
self, config_path: str, launcher: str = "accelerate"
) -> None:
"""Start a training job."""
import uuid
from datetime import datetime
job_id = str(uuid.uuid4())
log_file = f"/tmp/axolotl_training_{job_id}.log"
job = TrainingJob(
id=job_id,
config_path=config_path,
status="pending",
start_time=datetime.now(),
log_file=log_file,
total_epochs=3, # Default, should parse from config
)
self.jobs[job_id] = job
self.selected_job_id = job_id
log = self.query_one("#training-logs", Log)
log.clear()
log.write_line(f"🚀 Starting training job {job_id[:8]}...")
log.write_line(f"Config: {config_path}")
log.write_line(f"Launcher: {launcher}")
try:
if launcher == "accelerate":
cmd = ["accelerate", "launch", "-m", "axolotl.cli.train", config_path]
else:
cmd = [
"torchrun",
"--nproc_per_node=1",
"-m",
"axolotl.cli.train",
config_path,
]
with open(log_file, "w") as f:
process = subprocess.Popen(
cmd,
stdout=f,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
job.process = process
job.status = "running"
log.write_line("✅ Training started successfully!")
self.refresh_job_table()
self.monitor_training_output(job_id)
except Exception as e:
job.status = "failed"
job.end_time = datetime.now()
log.write_line(f"❌ Failed to start training: {str(e)}")
self.refresh_job_table()
def monitor_training_output(self, job_id: str) -> None:
"""Monitor training output and extract metrics."""
if job_id not in self.jobs:
return
job = self.jobs[job_id]
if not job.log_file:
return
def tail_log():
import re
import time
with open(job.log_file, "r") as f:
f.seek(0, 2) # Go to end of file
while job.status == "running":
line = f.readline()
if line:
# Parse training metrics from log
epoch_match = re.search(r"Epoch (\d+)/(\d+)", line)
if epoch_match:
job.current_epoch = int(epoch_match.group(1))
job.total_epochs = int(epoch_match.group(2))
loss_match = re.search(
r"loss['\"]?\s*:\s*([\d.]+)", line, re.IGNORECASE
)
if loss_match:
job.current_loss = float(loss_match.group(1))
job.losses.append(job.current_loss)
# Update log viewer
self.call_from_thread(self.append_training_log, line.strip())
else:
time.sleep(0.5)
thread = threading.Thread(target=tail_log, daemon=True)
thread.start()
def append_training_log(self, line: str) -> None:
"""Append line to training log."""
log = self.query_one("#training-logs", Log)
log.write_line(line)
@on(Button.Pressed, "#stop-training")
def handle_stop_training(self) -> None:
"""Stop selected training job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
log = self.query_one("#training-logs", Log)
log.write_line("⚠️ No job selected")
return
job = self.jobs[self.selected_job_id]
if job.status == "running" and job.process:
job.process.terminate()
job.status = "stopped"
job.end_time = datetime.now()
log = self.query_one("#training-logs", Log)
log.write_line(f"🛑 Training job {job.id[:8]} stopped")
self.refresh_job_table()
@on(Button.Pressed, "#resume-training")
async def handle_resume_training(self) -> None:
"""Resume a stopped training job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
log = self.query_one("#training-logs", Log)
log.write_line("⚠️ No job selected")
return
job = self.jobs[self.selected_job_id]
if job.status in ["stopped", "failed"]:
await self.start_training_job(job.config_path)
@on(Button.Pressed, "#clear-completed")
def handle_clear_completed(self) -> None:
"""Clear completed jobs from the list."""
completed_jobs = [
job_id
for job_id, job in self.jobs.items()
if job.status in ["completed", "failed", "stopped"]
]
for job_id in completed_jobs:
del self.jobs[job_id]
self.refresh_job_table()
log = self.query_one("#training-logs", Log)
log.write_line(f"🧹 Cleared {len(completed_jobs)} completed jobs")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh the job list and metrics."""
self.refresh_job_table()
self.update_selected_job_metrics()
if self.selected_job_id:
self.load_job_logs()
@on(Button.Pressed, "#view-logs")
def handle_view_logs(self) -> None:
"""View full logs for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
if job.log_file and Path(job.log_file).exists():
import subprocess
subprocess.run(["less", job.log_file])
def action_new_training(self) -> None:
"""Start a new training job."""
self.handle_new_training()
def action_stop_training(self) -> None:
"""Stop selected training job."""
self.handle_stop_training()
def action_resume_training(self) -> None:
"""Resume selected training job."""
self.handle_resume_training()
def action_refresh(self) -> None:
"""Refresh the display."""
self.handle_refresh()

View File

@@ -109,12 +109,6 @@ class AxolotlInputConfig(
"description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs" "description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
}, },
) )
reinit_weights: bool | None = Field(
default=None,
json_schema_extra={
"description": "Reinitialize model weights randomly instead of loading pretrained weights"
},
)
trainer_cls: str | None = Field( trainer_cls: str | None = Field(
default=None, default=None,

View File

@@ -1,119 +0,0 @@
"""E2E smoke test for diffusion training plugin."""
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
class TestDiffusion:
"""Test case for diffusion training plugin."""
def test_diffusion_smoke_test(self, temp_dir):
"""
Smoke test for diffusion training to ensure the plugin loads and trains without
error.
"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"trust_remote_code": True,
"sequence_len": 256,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
"logging_steps": 1,
"eval_steps": 3,
# Diffusion-specific config
"plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
"diffusion_mask_token_id": 16,
"diffusion_eps": 1e-3,
"diffusion_importance_weighting": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_diffusion_sft_labels(self, temp_dir):
"""Test that diffusion training properly handles SFT data with labels."""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"trust_remote_code": True,
"sequence_len": 256,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
"logging_steps": 1,
"eval_steps": 2,
# Diffusion-specific config
"plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
"diffusion_mask_token_id": 16,
"diffusion_eps": 1e-3,
"diffusion_importance_weighting": True,
# Ensure we have proper SFT labels
"train_on_inputs": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
# Verify that the dataset has labels
sample = dataset_meta.train_dataset[0]
assert "labels" in sample, "SFT dataset should have labels"
# Check that some labels are -100 (prompt tokens)
labels = sample["labels"]
if hasattr(labels, "tolist"):
labels = labels.tolist()
assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens"
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -1,290 +0,0 @@
"""Tests for diffusion model integration."""
# pylint: disable=redefined-outer-name,protected-access
from unittest.mock import Mock, patch
import pytest
import torch
from axolotl.integrations.diffusion.configuration import LlamaForDiffusionConfig
from axolotl.integrations.diffusion.models import LlamaForDiffusionLM
from axolotl.utils.dict import DictDefault
@pytest.fixture
def mock_tokenizer():
"""Create a mock tokenizer."""
tokenizer = Mock()
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
tokenizer.pad_token_id = 0
return tokenizer
@pytest.fixture
def diffusion_config():
"""Create a diffusion config."""
return LlamaForDiffusionConfig(
mask_token_id=32000,
eps=1e-3,
importance_weighting=False,
sample_packing=False,
# Basic llama config fields - smaller for testing
vocab_size=1000,
hidden_size=256,
intermediate_size=512,
num_hidden_layers=2,
num_attention_heads=4,
)
@pytest.fixture
def diffusion_model_instance(mock_tokenizer, diffusion_config):
"""Create a diffusion model instance for testing methods directly."""
# Create a minimal model instance for testing
model = object.__new__(LlamaForDiffusionLM)
model.config = diffusion_config
model._special_token_ids = {0, 1, 2} # pad, bos, eos
model.training = True
# Set tokenizer
model.set_tokenizer(mock_tokenizer)
return model
class TestDiffusionModel:
"""Test the DiffusionModel class."""
def test_forward_process_basic(self, diffusion_model_instance):
"""Test basic forward process without labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
noisy_batch, masked_indices, p_mask = (
diffusion_model_instance._forward_process(input_ids, eps=0.1)
)
# Check shapes
assert noisy_batch.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape
# Check that special tokens are not masked
special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
assert not masked_indices[special_token_positions].any()
# Check that mask token is applied
mask_token_id = diffusion_model_instance.config.mask_token_id
masked_positions = masked_indices
if masked_positions.any():
assert (noisy_batch[masked_positions] == mask_token_id).all()
def test_forward_process_with_labels(self, diffusion_model_instance):
"""Test forward process with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
noisy_batch, masked_indices, p_mask = (
diffusion_model_instance._forward_process(
input_ids, labels=labels, eps=0.1
)
)
# Check shapes
assert noisy_batch.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape
# Check that only answer tokens can be masked (where labels != -100)
non_answer_mask = labels == -100
# No masking should occur on non-answer tokens
assert not masked_indices[non_answer_mask].any()
# p_mask should be the same for all positions (sampled timestep),
# but masking is only applied to answer tokens
assert p_mask.shape == input_ids.shape
# Verify that masked_indices respects the answer mask
assert not masked_indices[non_answer_mask].any()
def test_forward_process_with_attention_mask(self, diffusion_model_instance):
"""Test forward process with attention mask."""
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
_, masked_indices, p_mask = diffusion_model_instance._forward_process(
input_ids, attention_mask=attention_mask, eps=0.1
)
# Check that padding tokens are not masked
padding_positions = attention_mask == 0
assert not masked_indices[padding_positions].any()
assert (p_mask[padding_positions] == 0).all()
def test_bidirectional_attention_mask_no_packing(self, diffusion_model_instance):
"""Test bidirectional attention mask without sample packing."""
input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
mask = diffusion_model_instance._create_bidirectional_attention_mask(
input_ids
)
# Should be all-to-all attention
expected_shape = (1, 1, 4, 4)
assert mask.shape == expected_shape
assert mask.all()
def test_bidirectional_attention_mask_with_packing(
self, diffusion_model_instance
):
"""Test bidirectional attention mask with sample packing."""
diffusion_model_instance.config.sample_packing = True
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
# Sample IDs: first sample (1), second sample (2)
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
mask = diffusion_model_instance._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Check that tokens within same sample can attend to each other
# but not across samples
assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other
assert mask[0, 0, 1, 2].item()
assert not mask[0, 0, 0, 3].item() # Can't attend across samples
assert not mask[0, 0, 2, 4].item()
assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other
def test_compute_loss_basic(self, diffusion_model_instance):
"""Test basic loss computation."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
# Create mock data for loss computation
vocab_size = 1000
seq_len = 5
logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
# Create a simple masked indices tensor (mask middle tokens)
masked_indices = torch.tensor([[False, True, True, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.5, 0.5, 0.1, 0.1]], dtype=torch.float)
loss = diffusion_model_instance._compute_diffusion_loss(
input_ids=input_ids,
logits=logits,
masked_indices=masked_indices,
p_mask=p_mask,
)
# Check that loss is computed
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
def test_compute_loss_with_labels(self, diffusion_model_instance):
"""Test loss computation with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
# Create mock data for loss computation
vocab_size = 1000
seq_len = 5
logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
# Create masked indices that only covers answer tokens
masked_indices = torch.tensor([[False, False, True, True, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.1, 0.5, 0.5, 0.1]], dtype=torch.float)
loss = diffusion_model_instance._compute_diffusion_loss(
input_ids=input_ids,
labels=labels,
logits=logits,
masked_indices=masked_indices,
p_mask=p_mask,
)
# Check that loss is computed
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
def test_compute_loss_no_masked_tokens(self, diffusion_model_instance):
"""Test loss computation when no tokens are masked."""
input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
# Create mock data for loss computation
vocab_size = 1000
seq_len = 3
logits = torch.randn(1, seq_len, vocab_size)
# No tokens masked
masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
loss = diffusion_model_instance._compute_diffusion_loss(
input_ids=input_ids,
logits=logits,
masked_indices=masked_indices,
p_mask=p_mask,
)
# Loss should be zero when no tokens are masked
assert loss.item() == 0.0
assert loss.requires_grad
def test_cache_special_token_ids(self, diffusion_model_instance):
"""Test caching of special token IDs."""
# Should cache BOS, EOS, PAD tokens
expected_tokens = {0, 1, 2} # pad, bos, eos
assert diffusion_model_instance._special_token_ids == expected_tokens
def test_cache_special_token_ids_no_tokenizer(self):
"""Test caching when no tokenizer is available."""
# Mock the parent model initialization to avoid loading pretrained weights
with patch('transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__'):
model = LlamaForDiffusionLM.__new__(LlamaForDiffusionLM)
model._cache_special_token_ids(None)
assert model._special_token_ids == set()
def test_forward_training_mode(self, diffusion_model_instance):
"""Test forward pass in training mode."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.bool)
# Mock the parent forward method
with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
mock_output = Mock()
mock_output.logits = torch.randn(1, 5, 32000)
mock_forward.return_value = mock_output
# Set training mode
diffusion_model_instance.training = True
result = diffusion_model_instance.forward(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)
# Should call parent forward and compute loss
assert mock_forward.called
assert hasattr(result, 'loss')
def test_forward_inference_mode(self, diffusion_model_instance):
"""Test forward pass in inference mode."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
# Mock the parent forward method
with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
mock_output = Mock()
mock_forward.return_value = mock_output
# Set inference mode
diffusion_model_instance.training = False
result = diffusion_model_instance.forward(
input_ids=input_ids,
return_dict=True
)
# Should just call parent forward without diffusion processing
assert mock_forward.called
assert result == mock_output