Compare commits

..

12 Commits

Author SHA1 Message Date
Wing Lian
cf8c93e2ee wip 2025-08-19 09:36:57 -04:00
Dan Saunders
63d2280999 nits 2025-08-18 19:17:24 +00:00
Dan Saunders
b210db2d15 fixes 2025-08-18 19:09:09 +00:00
Dan Saunders
556a69118f sample generation, tests fixes 2025-08-18 18:25:04 +00:00
Dan Saunders
8569675b26 Merge branch 'main' into diffusion 2025-08-18 10:07:55 -04:00
Dan Saunders
077b5a4358 cleanup; tests draft 2025-08-16 02:44:44 +00:00
Dan Saunders
234b7b3126 nits 2025-08-16 00:14:44 +00:00
Dan Saunders
e19be0c2d9 add back in reinit_weights (clobbered?); masking / pretrain fixes 2025-08-15 02:21:25 +00:00
Dan Saunders
479a454ae3 fixes + improvements 2025-08-14 16:11:37 -04:00
Dan Saunders
0a9341acde nits 2025-08-14 01:53:24 -04:00
Dan Saunders
d8b63804bc cleanup 2025-08-14 01:51:13 -04:00
Dan Saunders
3156c605d4 diffusion training plugin 2025-08-14 01:48:22 -04:00
52 changed files with 11576 additions and 13229 deletions

View File

@@ -12,6 +12,5 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: true
chat:
auto_reply: true

File diff suppressed because it is too large Load Diff

View File

@@ -41,12 +41,6 @@ model, and final model output, you may need at least 3TB of free disk space to k
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```
To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
training of the 120B model using Baseten Truss. You can read more about this recipe on
[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
be found on their
[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
@@ -67,23 +61,9 @@ mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
### Inferencing your fine-tuned model
#### vLLM
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
for more information about using a special vllm-openai docker image for inferencing with vLLM.
Optionally, vLLM can be installed from nightly:
```bash
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash
vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
```
#### SGLang
SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:

View File

@@ -44,7 +44,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -40,7 +40,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: ./outputs/last_run_prepared
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
@@ -41,7 +41,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: ./outputs/last_run_prepared
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
@@ -40,7 +40,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -53,7 +53,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -0,0 +1,57 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
pretraining_dataset:
- path: wikitext
name: wikitext-103-raw-v1
type: completion
field: text
plugins:
- diffusion.DiffusionPlugin
noise_schedule: cosine
min_mask_ratio: 0.15
max_mask_ratio: 0.85
eps: 5e-4
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 10
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
gradient_accumulation_steps: 8
micro_batch_size: 4
max_steps: 10000
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 3e-4
bf16: auto
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true
warmup_steps: 1000
save_strategy: steps
save_steps: 1000
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,58 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
val_set_size: 0.05
plugins:
- diffusion.DiffusionPlugin
noise_schedule: cosine
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
mask_token_id: 128002
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
eval_sample_packing: true
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 1e-5
bf16: auto
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true
warmup_steps: 1000
save_strategy: steps
eval_strategy: steps
save_steps: 500
eval_steps: 500
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -72,8 +72,3 @@ axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5
mistral-common==1.8.3
# TUI dependencies
textual==1.0.0
rich==14.1.0
tree_sitter_ruby==0.23.1

View File

@@ -118,9 +118,9 @@ def get_package_version():
extras_require = {
"flash-attn": ["flash-attn==2.8.3"],
"flash-attn": ["flash-attn==2.8.2"],
"ring-flash-attn": [
"flash-attn==2.8.3",
"flash-attn==2.8.2",
"ring-flash-attn>=0.1.7",
"yunchang==0.6.0",
],

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res
def get_image(self):
docker_tag = "main-py3.11-cu126-2.7.1"
docker_tag = "main-py3.11-cu124-2.6.0"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -200,7 +200,7 @@ class ModalCloud(Cloud):
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return f"H100:{count}"
return modal.gpu.H100(count=count)
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":

View File

@@ -64,7 +64,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer

View File

@@ -344,26 +344,6 @@ def delinearize_llama4(model: str, output: str):
cli.add_command(lm_eval)
@cli.command()
def tui():
"""
Launch the Axolotl Terminal User Interface (TUI).
Provides an interactive interface for configuration management,
training monitoring, dataset handling, and model operations.
"""
try:
from axolotl.tui.app import run
run()
except ImportError:
click.echo(
"TUI dependencies not installed. Install with: pip install textual rich"
)
except Exception as e:
click.echo(f"Error launching TUI: {e}")
def main():
cli()

View File

@@ -97,8 +97,7 @@ def do_cli(
"""
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
is_preprocess = kwargs.pop("is_preprocess", True)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -3,12 +3,11 @@
import random
from copy import deepcopy
from itertools import product
from typing import Any
def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, Any]]:
) -> list[dict[str, list]]:
"""
Recursively generates all possible configurations by applying sweeps to the base config.

View File

@@ -4,7 +4,6 @@ import os
import subprocess # nosec
import sys
import tempfile
from pathlib import Path
from typing import Any, Iterator, Literal
import yaml
@@ -89,12 +88,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
# Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
is_group = len(permutations) > 1
base_output_dir = base_config.get("output_dir", "./model-out")
for idx, permutation in enumerate(permutations, start=1):
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
permutation_id = f"sweep{idx:04d}"
permutation["output_dir"] = str(permutation_dir / permutation_id)
for permutation in permutations:
# pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile(
mode="w",

View File

@@ -10,6 +10,7 @@ import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
Trainer,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
@@ -385,10 +386,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters:
if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
trainer_kwargs["processing_class"] = self.tokenizer
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None

View File

@@ -82,7 +82,9 @@ class AxolotlTrainer(
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
self._stored_metrics = defaultdict(
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
)
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -272,6 +274,18 @@ class AxolotlTrainer(
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
if (self.args.accelerator_config is not None
and self.args.accelerator_config.split_batches
and self.args.accelerator_config.dispatch_batches
):
if self.args.sample_packing and self.args.pretraining:
if not self.args.eval_sample_packing and not is_training:
dataloader_params["batch_size"] *= self.accelerator.num_processes
else:
dataloader_params["batch_size"] = self.accelerator.num_processes
elif not self.args.sample_packing and self.args.pretraining:
dataloader_params["batch_size"] *= self.accelerator.num_processes
if self.args.sample_packing and (
(is_training and not self.args.pretraining)
or (not is_training and self.args.eval_sample_packing is not False)
@@ -573,9 +587,26 @@ class AxolotlTrainer(
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
# Add reduced stored metrics to logs
for key, metric_data in self._stored_metrics[train_eval].items():
values = torch.tensor(metric_data["values"])
reduction_type = metric_data["reduction"]
if reduction_type == "mean":
logs[key] = values.mean().item()
elif reduction_type == "min":
logs[key] = values.min().item()
elif reduction_type == "max":
logs[key] = values.max().item()
elif reduction_type == "sum":
logs[key] = values.sum().item()
else:
raise NotImplementedError(
"Metric reduction must be one of [mean, min, max, sum]"
)
logs[key] = round(logs[key], 4)
if is_main_process():
# Add memory usage
@@ -592,10 +623,27 @@ class AxolotlTrainer(
return super().log(logs, start_time)
def store_metrics(
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
self,
metrics: dict[str, float] | dict[str, tuple[int | float, str]],
train_eval: Literal["train", "eval"] = "train",
reduction: Literal["mean", "min", "max", "sum"] = "mean",
) -> None:
"""
Store metrics with specified reduction type.
Args:
metrics: Dictionary of metric names to values, or metric names to (value,
reduction_type) tuples.
train_eval: Whether this is for training or evaluation.
"""
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
if isinstance(value, tuple):
metric_value, metric_reduction = value
else:
metric_value, metric_reduction = value, reduction
self._stored_metrics[train_eval][key]["values"].append(metric_value)
self._stored_metrics[train_eval][key]["reduction"] = metric_reduction
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey

View File

@@ -147,7 +147,7 @@ class BasePlugin:
"""
# pylint: disable=unused-argument
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
"""Returns a custom class for the trainer.
Args:

View File

@@ -0,0 +1,125 @@
# Diffusion LM Training Plugin for Axolotl
This plugin enables diffusion language model training using the LLaDA (Large Language
And Diffusion Assistant) approach within the Axolotl framework.
## Overview
LLaDA is a diffusion-based approach to language model training that uses:
- **Random token masking** during training instead of next-token prediction
- **Bidirectional attention** to allow the model to see the full context
- **Importance weighting** based on masking probabilities for stable training
This approach can lead to more robust language models with better understanding of
bidirectional context.
## Installation
The plugin is included with Axolotl. To use it, simply add the plugin configuration to
your training config.
## Quickstart
### Basic Configuration
Add the following to your Axolotl configuration YAML:
```yaml
# Enable diffusion LM training plugin
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
# Diffusion-specific configuration
noise_schedule: linear # or "cosine"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
mask_token_id: 128002
# Sample generation (optional)
generate_samples: true
generation_interval: 100
num_generation_samples: 3
generation_steps: 128
generation_temperature: 0.0
generation_max_length: 100
# Model configuration
base_model: meta-llama/Llama-3.2-1B
model_type: llama
# Standard Axolotl configuration
datasets:
- path: your_dataset
...
# Other config
sequence_len: 1024
micro_batch_size: 8
gradient_accumulation_steps: 4
learning_rate: 3e-4
```
## Supported Models
Any models that support 4D attention masks should work out of the box. If not, please
create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues)!
## How It Works
### Random Masking
During training, tokens are randomly masked based on a sampled timestep:
- Sample timestep `t` uniformly from [0, 1]
- Calculate masking probability: `p = (1 - eps) * t + eps`
- Randomly mask tokens with probability `p`
### Bidirectional Attention
The plugin uses native 4D attention masks to:
- Enable bidirectional attention without patches
- Allow all tokens to attend to all other tokens
- Maintain proper padding masks
- Work with modern `transformers` models out of the box
### Diffusion Loss
Loss is computed only on masked tokens with (optional) importance weighting:
```python
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
```
## Sample Generation
When `generate_samples: true`, the plugin generates samples during training:
```
Sample 1:
Original (45 tokens): The quick brown fox jumps over the lazy dog...
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
Generated: The quick brown fox jumps over the lazy dog...
```
Samples are logged to console and wandb (if enabled).
## Metrics and Monitoring
The plugin adds several metrics to track diffusion training:
- `train/loss`: Weighted diffusion loss
- `train/accuracy`: Accuracy on masked tokens
- `train/mask_ratio`: Average fraction of tokens masked
- `train/num_masked_tokens`: Number of tokens masked
- `train/avg_p_mask`: Average masking probability
- `train/ce_loss`: Unweighted cross-entropy loss
- `train/importance_weight_avg`: Average importance weight
## Limitations
- No flash attention support
## References
- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
- [Axolotl Documentation](https://docs.axolotl.ai/)

View File

@@ -0,0 +1,6 @@
"""Diffusion LM training plugin init."""
from .args import DiffusionArgs
from .plugin import DiffusionPlugin
__all__ = ["DiffusionArgs", "DiffusionPlugin"]

View File

@@ -0,0 +1,70 @@
"""Config args for diffusion LM training."""
from typing import Literal
from pydantic import BaseModel, Field
class DiffusionArgs(BaseModel):
"""Arguments for diffusion LM training plugin."""
# Noise schedule config
noise_schedule: Literal["linear", "cosine"] = Field(
default="linear", description="Type of noise schedule for diffusion training"
)
min_mask_ratio: float = Field(
default=0.1,
ge=0.0,
le=1.0,
description="Minimum masking ratio for diffusion noise schedule",
)
max_mask_ratio: float = Field(
default=0.9,
ge=0.0,
le=1.0,
description="Maximum masking ratio for diffusion noise schedule",
)
num_diffusion_steps: int = Field(
default=128, ge=1, description="Number of diffusion timesteps"
)
eps: float = Field(
default=1e-3,
ge=0.0,
le=1.0,
description="Epsilon value for minimum masking probability in forward process",
)
# Training config
importance_weighting: bool = Field(
default=True,
description="Apply importance weighting to loss based on masking probability",
)
mask_token_id: int = Field(
default=128002,
description=(
"Token ID to use for masking. Default is 128002 "
"(<|reserved_special_token_0|> for Llama 3.2)"
),
)
# Sample generation config
generate_samples: bool = Field(
default=True, description="Enable sample generation during training"
)
generation_interval: int = Field(
default=100, ge=1, description="Generate samples every N steps"
)
num_generation_samples: int = Field(
default=3, ge=1, description="Number of samples to generate each time"
)
generation_steps: int = Field(
default=128, ge=1, description="Number of diffusion steps for generation"
)
generation_temperature: float = Field(
default=0.0,
ge=0.0,
description="Temperature for generation sampling (0.0 = deterministic)",
)
generation_max_length: int = Field(
default=100, ge=1, description="Maximum sequence length for generation"
)

View File

@@ -0,0 +1,113 @@
"""Callbacks for diffusion training."""
import wandb
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.utils.logging import get_logger
from .generation import generate_samples
LOG = get_logger(__name__)
class DiffusionGenerationCallback(TrainerCallback):
"""Callback for generating samples during diffusion training."""
def __init__(self, trainer):
self.trainer = trainer
# pylint: disable=unused-argument
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Generate samples at specified intervals."""
if (
state.global_step > 0
and state.global_step % self.trainer.config.generation_interval == 0
):
# Use eval dataloader if available, otherwise use train dataloader
if (
hasattr(self.trainer, "eval_dataset")
and self.trainer.eval_dataset is not None
):
dataloader = self.trainer.callback_handler.eval_dataloader
else:
dataloader = self.trainer.callback_handler.train_dataloader
# Generate samples
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.tokenizer,
dataloader=dataloader,
num_generation_samples=self.trainer.config.num_generation_samples,
max_length=self.trainer.config.generation_max_length,
num_diffusion_steps=self.trainer.config.generation_steps,
temperature=self.trainer.config.generation_temperature,
mask_token_id=self.trainer.config.mask_token_id,
)
# Log samples
self._log_samples(samples, state.global_step)
def _log_samples(self, samples: list, step: int):
"""Log generated samples."""
if not samples:
return
LOG.info("=" * 60)
LOG.info("GENERATED SAMPLES")
LOG.info("=" * 60)
for i, sample_data in enumerate(samples, 1):
original = sample_data["original"]
masked = sample_data["masked"]
generated = sample_data["generated"]
mask_ratio = sample_data["mask_ratio"]
masked_tokens = sample_data["masked_tokens"]
total_tokens = sample_data["total_tokens"]
LOG.info(f"\nSample {i}:")
LOG.info(f"\tOriginal ({total_tokens} tokens): {original}")
LOG.info(
f"\tMasked ({masked_tokens}/{total_tokens} tokens, "
f"{mask_ratio:.1%}): {masked}"
)
LOG.info(f"\tGenerated: {generated}")
LOG.info("=" * 60)
if self.trainer.config.use_wandb and self.trainer.state.is_world_process_zero:
if wandb.run is not None:
wandb.log(
{
"generated_samples": wandb.Table(
columns=[
"step",
"original",
"masked",
"generated",
"mask_ratio",
"masked_tokens",
"total_tokens",
],
data=[
[
step,
sample["original"],
sample["masked"],
sample["generated"],
f"{sample['mask_ratio']:.1%}",
sample["masked_tokens"],
sample["total_tokens"],
]
for sample in samples
],
)
},
step=step,
)

View File

@@ -0,0 +1,269 @@
"""Sample generation utilities for diffusion training."""
import logging
from typing import Any, List, Optional
import torch
logger = logging.getLogger(__name__)
def generate_samples(
model: torch.nn.Module,
tokenizer: Any,
dataloader: Optional[Any] = None,
num_generation_samples: int = 3,
max_length: int = 100,
num_diffusion_steps: int = 128,
temperature: float = 0.0,
mask_token_id: int = 32000,
) -> List[dict]:
"""
Generate text samples using the diffusion model by randomly masking sequences from
the given dataset and running the reverse diffusion process.
Args:
model: The wrapped or unwrapped model
tokenizer: Tokenizer for encoding/decoding
dataloader: Validation dataloader (for sampling sequences)
num_generation_samples: Number of samples to generate
max_length: Maximum length of sequences to use
num_diffusion_steps: Number of diffusion steps for generation
temperature: Temperature for sampling (0.0 = deterministic)
mask_token_id: Token ID used for masking
Returns:
List of dictionaries with original text, masked text, and generated text
"""
if dataloader is None:
logger.warning("No validation dataloader provided, cannot generate samples")
return []
# Get the actual model (unwrap if needed)
unwrapped_model = model.module if hasattr(model, "module") else model
unwrapped_model.eval()
generations = []
# Sample sequences from validation dataset
sampled_sequences = _sample_sequences_from_dataloader(
dataloader, num_generation_samples, max_length, unwrapped_model.device
)
logger.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
# Generate samples using reverse diffusion process
with torch.no_grad():
for original_sequence in sampled_sequences:
generation_result = _generate(
unwrapped_model,
tokenizer,
original_sequence,
num_diffusion_steps,
temperature,
mask_token_id,
)
generations.append(generation_result)
unwrapped_model.train()
return generations
def _sample_sequences_from_dataloader(
dataloader: Any, num_samples: int, max_length: int, device: torch.device
) -> List[torch.Tensor]:
"""Sample sequences from validation dataloader."""
sampled_sequences = []
sample_count = 0
# Add randomness by skipping a random number of batches
skip_batches = torch.randint(0, 6, (1,)).item()
batch_count = 0
for batch in dataloader:
# Skip some batches for variety
if batch_count < skip_batches:
batch_count += 1
continue
if sample_count >= num_samples:
break
batch_count += 1
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask")
# Randomly sample from sequences in this batch
batch_indices = torch.randperm(input_ids.size(0)).tolist()
for i in batch_indices:
if sample_count >= num_samples:
break
# Get actual sequence length (non-padded)
if attention_mask is not None:
seq_len = attention_mask[i].sum().item()
else:
seq_len = input_ids.size(1)
# Limit sequence length to max_length
actual_length = min(seq_len, max_length)
if actual_length < 10: # Skip very short sequences
continue
# Extract the sequence
sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)
sampled_sequences.append(sequence)
sample_count += 1
return sampled_sequences
def _generate(
model: torch.nn.Module,
tokenizer: Any,
original_sequence: torch.Tensor,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
) -> dict:
"""Generate a single sample using reverse diffusion."""
# Get original text for comparison
original_text = tokenizer.decode(
original_sequence[0].cpu(), skip_special_tokens=True
)
# Apply custom masking with random ratio (10% to 70%)
total_tokens = original_sequence.size(1)
min_ratio, max_ratio = 0.1, 0.7
target_mask_ratio = torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio
target_masked_tokens = int(total_tokens * target_mask_ratio)
# Create random mask indices
mask_positions = torch.randperm(total_tokens)[:target_masked_tokens]
masked_indices = torch.zeros(
1, total_tokens, dtype=torch.bool, device=original_sequence.device
)
masked_indices[0, mask_positions] = True
# Create masked sequence
masked_sequence = original_sequence.clone()
masked_sequence[masked_indices] = mask_token_id
# Calculate actual mask ratio
masked_tokens = masked_indices.sum().item()
mask_ratio = masked_tokens / total_tokens
# Get masked text for comparison
masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)
# Clean up mask token representation
masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)
# Run reverse diffusion process
sequence = masked_sequence.clone()
for step in range(num_diffusion_steps):
sequence = _diffusion_step(
model, sequence, step, num_diffusion_steps, temperature, mask_token_id
)
# Get final generated text
generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)
return {
"original": original_text,
"masked": masked_text,
"generated": generated_text,
"mask_ratio": mask_ratio,
"masked_tokens": masked_tokens,
"total_tokens": total_tokens,
"formatted": (
f"Original: '{original_text}' → Masked: '{masked_text}' "
f"({mask_ratio:.1%}) → Generated: '{generated_text}'"
),
}
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
"""Clean up masked text for display."""
mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
cleaned = masked_text.replace(mask_token_repr, "[MASK]")
if hasattr(tokenizer, "special_tokens_map"):
for token_value in tokenizer.special_tokens_map.values():
if token_value and isinstance(token_value, str):
cleaned = cleaned.replace(token_value, "")
cleaned = " ".join(cleaned.split()).strip()
return cleaned
def _diffusion_step(
model: torch.nn.Module,
sequence: torch.Tensor,
step: int,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
) -> torch.Tensor:
"""Perform a single diffusion step with remasking."""
# Only process if there are masked tokens remaining
current_mask = sequence == mask_token_id
if not current_mask.any():
return sequence
# Create bidirectional attention mask for diffusion
batch_size, seq_len = sequence.shape
attention_mask = torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device
)
# Forward pass
outputs = model(input_ids=sequence, attention_mask=attention_mask)
logits = outputs.logits
# Only sample at currently masked positions
if current_mask.any():
masked_logits = logits[current_mask]
# Apply temperature scaling
if temperature > 0:
scaled_logits = masked_logits / temperature
else:
scaled_logits = masked_logits
# Suppress mask token in outputs
scaled_logits[:, mask_token_id] = -float("inf")
# Sample predictions
if temperature > 0:
# Add Gumbel noise for sampling
gumbel_noise = -torch.log(
-torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))
)
gumbel_logits = scaled_logits + gumbel_noise
predicted_tokens = torch.argmax(gumbel_logits, dim=-1)
else:
# Deterministic sampling when temperature is 0
predicted_tokens = torch.argmax(scaled_logits, dim=-1)
# Calculate probabilities for confidence scoring
probs = torch.softmax(scaled_logits, dim=-1)
predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]
# Determine how many tokens to unmask this step
remaining_masked = current_mask.sum().item()
if step == num_diffusion_steps - 1:
num_to_unmask = remaining_masked
else:
unmask_ratio = 1.0 / (num_diffusion_steps - step)
num_to_unmask = max(1, int(remaining_masked * unmask_ratio))
# Select highest confidence predictions to unmask
if num_to_unmask >= remaining_masked:
sequence[current_mask] = predicted_tokens
else:
_, top_indices = predicted_token_probs.topk(num_to_unmask)
mask_positions = torch.where(current_mask)[1]
positions_to_unmask = mask_positions[top_indices]
sequence[0, positions_to_unmask] = predicted_tokens[top_indices]
return sequence

View File

@@ -0,0 +1,41 @@
"""Diffusion LM training plugin for Axolotl."""
from peft import PeftModel
from transformers import PreTrainedModel
from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .trainer import DiffusionTrainer
LOG = get_logger(__name__)
class DiffusionPlugin(BasePlugin):
"""
Plugin for diffusion language model training.
This plugin enables diffusion-based training using the LLaDA approach, which uses
random masking and bidirectional attention to train language models.
"""
def __init__(self):
super().__init__()
self.cfg = None
def get_input_args(self) -> str:
"""Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs"
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Perform actions after model is loaded."""
self.cfg = cfg
def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None:
"""Return custom trainer class for diffusion training."""
return DiffusionTrainer
def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
"""Configure trainer after creation."""
trainer.set_config(cfg)

View File

@@ -0,0 +1,336 @@
"""Custom trainer for diffusion LM training."""
from typing import Any, Literal
import torch
import torch.nn.functional as F
from torch import nn
from transformers.masking_utils import find_packed_sequence_indices
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.diffusion.utils import create_bidirectional_block_mask
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback
LOG = get_logger(__name__)
class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
"""Custom trainer for diffusion LM training that overrides loss computation."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = None
self._special_token_ids = None
def set_config(self, config: DictDefault):
"""Set config for diffusion training."""
self.config = config
self._cache_special_token_ids()
if config.generate_samples:
generation_callback = DiffusionGenerationCallback(self)
self.add_callback(generation_callback)
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor],
return_outputs: bool = False,
num_items_in_batch: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Override compute_loss to use diffusion loss."""
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask")
labels = inputs.get("labels")
position_ids = inputs.get("position_ids")
if input_ids is None:
raise ValueError("input_ids is required for diffusion training")
loss, outputs = self._compute_diffusion_loss(
model, input_ids, attention_mask, labels, position_ids
)
if return_outputs:
return loss, outputs
return loss
def _cache_special_token_ids(self):
"""Cache special token IDs to avoid repeated tokenizer access."""
if self.processing_class is None:
self._special_token_ids = set()
return
tokenizer = self.processing_class
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
@torch.compile
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
min_p: float = 0.0,
max_p: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = min_p + (max_p - min_p) * (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = self.config.mask_token_id
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask
@torch.compile
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None
) -> torch.Tensor:
"""
Create bidirectional attention mask to override default causal masking. Handles
sample-packed sequences where different samples are identified by different
attention mask values.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len]
position_ids: Position ids [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not self.config.sample_packing:
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
)
if position_ids is None:
# Create attention mask by comparing sample IDs element-wise
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
# Tokens can attend to each other if they have the same non-zero sample ID
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
# Add head dimension: [batch_size, 1, seq_len, seq_len]
bidirectional_mask = bidirectional_mask.unsqueeze(1)
return bidirectional_mask
if self._config.flex_attention:
block_mask = create_bidirectional_block_mask(
input_ids, attention_mask, position_ids
)
else:
packed_seq_mask = find_packed_sequence_indices(position_ids)
block_mask = packed_seq_mask.unsqueeze(2) == packed_seq_mask.unsqueeze(1)
return block_mask
def _compute_diffusion_loss(
self,
model: nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | Any]:
"""
Compute diffusion loss.
Args:
model: The model to compute loss for.
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
position_ids: Position ids [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
metrics: Dictionary of metrics.
"""
# Apply forward process
noisy_batch, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self._config.eps, self._config.min_mask_ratio, self._config.max_mask_ratio
)
# Create bidirectional attention mask (optional: use causal if you want strict AR behavior)
bidirectional_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask, position_ids
)
# Forward pass
outputs = model(
input_ids=noisy_batch,
attention_mask=bidirectional_mask,
)
logits = outputs.logits # [B, L, V]
# ----- AR label shift toggle -----
use_ar_shift = False
if use_ar_shift:
# Predict token at t from logits at t-1: drop last logit step, drop first target step
logits_eff = logits[:, :-1, :]
input_ids_eff = input_ids[:, 1:]
masked_indices_eff = masked_indices[:, 1:]
p_mask_eff = p_mask[:, 1:]
labels_eff = labels[:, 1:] if labels is not None else None
else:
logits_eff = logits
input_ids_eff = input_ids
masked_indices_eff = masked_indices
p_mask_eff = p_mask
labels_eff = labels
if masked_indices_eff.sum() > 0:
valid_indices = torch.where(masked_indices_eff)
batch_indices, seq_indices = valid_indices
masked_logits = logits_eff[batch_indices, seq_indices]
masked_targets = input_ids_eff[batch_indices, seq_indices]
masked_p_mask = p_mask_eff[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.config.importance_weighting:
masked_p_mask = masked_p_mask.float().clamp_min(1e-6)
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
# Final loss: sum weighted losses, normalize
if labels_eff is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels_eff != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
loss_per_sample = torch.zeros(
input_ids.shape[0], device=input_ids.device
)
for i in range(input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.any():
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / answer_lengths[i]
loss = loss_per_sample.mean()
else:
# Original normalization for non-SFT data
loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
ce_loss = token_loss.mean()
# Compute accuracy on masked tokens
with torch.no_grad():
pred_tokens = masked_logits.argmax(dim=-1)
accuracy = (pred_tokens == masked_targets).float().mean()
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
accuracy = torch.tensor(0.0, device=input_ids.device)
ce_loss = torch.tensor(0.0, device=input_ids.device)
masked_p_mask = torch.tensor(1.0, device=input_ids.device)
# Keep eff tensors around for metrics
masked_indices_eff = masked_indices
p_mask_eff = p_mask
labels_eff = labels
# Metrics (aligned to the effective tensors)
if masked_indices_eff.any():
avg_p = p_mask_eff[masked_indices_eff].float().mean().item()
num_masked = int(masked_indices_eff.sum().item())
mask_ratio = masked_indices_eff.float().mean().item()
else:
avg_p = 0.0
num_masked = 0
mask_ratio = 0.0
metrics = {
"loss": float(loss.detach()),
"accuracy": float(accuracy.detach()),
"mask_ratio": mask_ratio,
"num_masked_tokens": (num_masked, "sum"),
"avg_p_mask": avg_p,
"ce_loss": float(ce_loss.detach()),
}
# SFT-specific metrics (aligned)
if labels_eff is not None:
answer_mask = labels_eff != -100
metrics["answer_ratio"] = answer_mask.float().mean().item()
metrics["avg_answer_length"] = answer_mask.sum(dim=1).float().mean().item()
if self.config.importance_weighting:
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
self.store_metrics(metrics, train_eval=train_eval)
return loss, outputs

View File

@@ -0,0 +1,50 @@
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from transformers.masking_utils import find_packed_sequence_indices, packed_sequence_mask_function
def create_bidirectional_block_mask(
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
) -> "BlockMask":
"""
Creates a bidirectional block mask for FlexAttention.
Args:
input_ids: Input token ids [batch_size, seq_len]
attention_mask: Padding mask [batch_size, seq_len]
Returns:
BlockMask for bidirectional attention with padding
"""
batch_size, seq_len = input_ids.shape
if position_ids is not None:
packed_seq_mask = find_packed_sequence_indices(position_ids)
mask_fn =packed_sequence_mask_function(packed_seq_mask, batch_size, seq_len)
elif attention_mask is None:
# If no padding mask, all positions can attend to all positions
def mask_fn(b, h, q_idx, kv_idx):
# Always return True for bidirectional attention
return True
else:
# Convert attention_mask to boolean if needed
attention_mask = attention_mask.bool()
def mask_fn(b, h, q_idx, kv_idx):
# Both query and key positions must be valid (not padding)
return attention_mask[b, q_idx] & attention_mask[b, kv_idx]
# Create the block mask
block_mask = create_block_mask(
mask_fn,
B=batch_size,
H=None, # Will be set by the attention layer
Q_LEN=seq_len,
KV_LEN=seq_len,
device=input_ids.device,
_compile=True,
)
return block_mask

View File

@@ -57,7 +57,7 @@ class SpectrumPlugin(BasePlugin):
Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
"""
base_url = "https://raw.githubusercontent.com/cognitivecomputations/spectrum/main/model_snr_results/"
base_url = "https://raw.githubusercontent.com/QuixiAI/spectrum/main/model_snr_results/"
base_path = "./model_snr_results/"
snr_file_template = "snr_results_{model_name_slug}.json"

View File

@@ -681,6 +681,23 @@ class ModelLoader:
return hf_ds_cfg
def _load_model_from_config(self) -> PreTrainedModel:
"""Load model with random initialization using from_config."""
if self.auto_model_loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
return self.auto_model_loader.from_config(config=self.model_config)
return self.auto_model_loader(config=self.model_config)
def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:
"""Load model from pretrained weights."""
loader = model_loader_class or self.auto_model_loader
kwargs = {
**self.model_kwargs,
"config": self.model_config,
"trust_remote_code": self.cfg.trust_remote_code or False,
**self.model_kwargs,
}
return loader.from_pretrained(self.base_model, **kwargs)
def _build_model(self) -> bool:
"""Load model, with load strategy depending on config."""
skip_move_to_device = False
@@ -695,7 +712,8 @@ class ModelLoader:
if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True
# Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
# Don't delete device_map for QLoRA + FSDP - it was set correctly in
# _set_device_map
if (
"device_map" in self.model_kwargs
and not self.is_qlora_and_fsdp_enabled
@@ -724,6 +742,11 @@ class ModelLoader:
or self.cfg.qlora_sharded_model_loading
)
):
if self.cfg.reinit_weights:
LOG.warning(
"reinit_weights is not supported with sharded quantized loading. "
"Loading from pretrained weights instead."
)
quant_storage = self.cfg.torch_dtype
quantization_config = getattr(
self.model_config, "quantization_config", None
@@ -739,33 +762,12 @@ class ModelLoader:
quantization_config=quantization_config,
)
skip_move_to_device = True
elif (
self.model_config.model_type in ["llama", "llama4"]
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
# Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading()
# Load model with random initialization if specified
if self.cfg.random_init_weights:
# AutoModel classes support the from_config method
if self.auto_model_loader in [
AutoModelForCausalLM,
AutoModelForVision2Seq,
]:
self.model = self.auto_model_loader.from_config(
config=self.model_config,
)
else:
self.model = self.auto_model_loader(config=self.model_config)
else:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
**self.model_kwargs,
)
elif self.model_type == "MambaLMHeadModel":
if self.cfg.reinit_weights:
LOG.warning(
"reinit_weights is not supported with MambaLMHeadModel. "
"Loading from pretrained weights instead."
)
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
@@ -778,41 +780,27 @@ class ModelLoader:
self.base_model,
**self.model_kwargs,
)
elif (
self.model_type
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
):
if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
self.model = getattr(transformers, self.model_type).from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
elif self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
# Please don't remove underscore binding without reading the fn docstring.
# Please don't remove underscore binding without reading the fn docstring
_ = self._configure_zero3_memory_efficient_loading()
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
if (
self.model_type
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
# Use model type from transformers
model_loader_class = getattr(transformers, self.model_type)
else:
# Use auto model loader (handles gptq and default cases)
model_loader_class = self.auto_model_loader
if self.cfg.reinit_weights:
self.model = self._load_model_from_config()
else:
self.model = self._load_model_from_pretrained(model_loader_class)
if is_deepspeed_zero3_enabled():
skip_move_to_device = True

View File

@@ -187,7 +187,7 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
if module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to(

View File

@@ -72,10 +72,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
builder_kwargs["message_field_training"] = message_field_training
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
def format_message(x):
return x
format_message = (
lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
)
if chat_template == "chatml":
from axolotl.core.chat.format.chatml import format_message # noqa F811
if chat_template.startswith("llama3"):

View File

@@ -75,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC):
) -> BatchEncoding:
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
if not prompt:
LOG.warning("Empty text requested for tokenization.")
LOG.warning_once("Empty text requested for tokenization.")
return empty
result = self.tokenizer(

View File

@@ -253,9 +253,7 @@ def save_trained_model(
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
if ( # pylint: disable=too-many-nested-blocks
trainer.is_fsdp_enabled or cfg.fsdp_config
):
if trainer.is_fsdp_enabled or cfg.fsdp_config:
if cfg.fsdp_config or cfg.fsdp:
if cfg.fsdp_config.final_state_dict_type:
state_dict_type = cfg.fsdp_config.final_state_dict_type
@@ -287,8 +285,6 @@ def save_trained_model(
if trainer.accelerator.is_main_process:
# move all files in merged_path to cfg.output_dir
for merged_file in Path(merged_path).iterdir():
if (Path(cfg.output_dir) / merged_file.name).exists():
(Path(cfg.output_dir) / merged_file.name).unlink()
shutil.move(str(merged_file), cfg.output_dir)
shutil.rmtree(merged_path) # remove what should be an empty dir
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207

View File

@@ -1,216 +0,0 @@
# Axolotl TUI (Terminal User Interface)
A comprehensive Terminal User Interface for Axolotl, providing an interactive way to manage configurations, training jobs, datasets, models, and system monitoring.
## Features
### 🏠 Main Dashboard
- **Welcome Screen**: Central hub with quick access to all features
- **Keyboard Navigation**: Efficient navigation with keyboard shortcuts
- **Screen Management**: Easy switching between different functional areas
### 📝 Configuration Management
- **YAML Editor**: Syntax-highlighted editor for Axolotl configurations
- **Real-time Validation**: Instant config validation with detailed error reporting
- **File Browser**: Navigate and select configuration files
- **Template Loading**: Load example configurations
- **Remote Config Support**: Load configurations from URLs
**Key Shortcuts:**
- `Ctrl+N`: New configuration
- `Ctrl+S`: Save configuration
- `Ctrl+V`: Validate configuration
- `Ctrl+E`: Toggle edit mode
### 🚀 Training Management
- **Job Launcher**: Start training with different launchers (accelerate, torchrun)
- **Real-time Monitoring**: Live training progress and metrics
- **Loss Visualization**: Sparkline charts for loss curves
- **Job Control**: Start, stop, resume, and manage multiple training jobs
- **Log Streaming**: Real-time log viewing and filtering
**Key Shortcuts:**
- `Ctrl+T`: New training job
- `Ctrl+R`: Resume training
- `Ctrl+X`: Stop training
- `R`: Refresh status
### 📊 Dataset Management
- **Dataset Browser**: Explore local and remote datasets
- **Preview & Statistics**: View dataset samples and metadata
- **Preprocessing**: Run dataset preprocessing with progress tracking
- **HuggingFace Integration**: Download and manage HF datasets
- **Format Detection**: Automatic dataset format recognition
**Key Shortcuts:**
- `Ctrl+P`: Preprocess dataset
- `Ctrl+V`: Preview dataset
- `Ctrl+I`: Dataset information
- `R`: Refresh dataset list
### 🤖 Model Management
- **Model Discovery**: Automatically find trained models
- **LoRA Operations**: Merge LoRA adapters with base models
- **Quantization**: Quantize models for deployment
- **Evaluation**: Run model evaluation benchmarks
- **Storage Info**: View model sizes and storage details
**Key Shortcuts:**
- `Ctrl+M`: Merge LoRA
- `Ctrl+Q`: Quantize model
- `Ctrl+E`: Evaluate model
- `R`: Refresh model list
### 💬 Inference & Testing
- **Interactive Chat**: Chat interface for model testing
- **Parameter Tuning**: Adjust inference parameters (temperature, top-p, max tokens)
- **Model Loading**: Load and switch between different models
- **Chat History**: Save and load conversation history
- **Gradio Integration**: Launch Gradio web interface
**Key Shortcuts:**
- `Ctrl+Enter`: Send message
- `Ctrl+C`: Clear chat
- `Ctrl+L`: Load model
- `Ctrl+S`: Save chat
### 📈 System Monitoring
- **Resource Monitoring**: Real-time CPU, GPU, and memory usage
- **Process Management**: View and manage running processes
- **Performance Graphs**: Historical usage charts with sparklines
- **GPU Information**: Detailed GPU status and memory usage
- **Temperature Monitoring**: System temperature tracking
**Key Shortcuts:**
- `R`: Refresh metrics
- `Ctrl+K`: Kill selected process
## Installation
### Dependencies
```bash
pip install textual==1.0.0 rich==14.1.0
```
### Launch TUI
```bash
# From command line
python -m axolotl.cli.main tui
# From Python code
from axolotl.tui.app import run
run()
```
## Architecture
### Screen Structure
```
AxolotlTUI (Main App)
├── WelcomeScreen (Dashboard)
├── ConfigScreen (Configuration Management)
├── TrainingScreen (Training Management)
├── DatasetScreen (Dataset Management)
├── ModelScreen (Model Management)
├── InferenceScreen (Inference & Testing)
└── MonitorScreen (System Monitoring)
```
### Key Components
- **BaseScreen**: Common functionality for all screens
- **Screen Navigation**: Stack-based screen management
- **Event Handling**: Reactive UI updates
- **Background Tasks**: Non-blocking operations
- **State Management**: Shared application state
### Integration Points
- **CLI Commands**: Seamless integration with existing axolotl CLI
- **Configuration System**: Uses axolotl's native config loading
- **Training Pipeline**: Integrates with axolotl training functions
- **Model Loading**: Compatible with axolotl model management
## Usage Examples
### 1. Creating a New Configuration
1. Launch TUI: `python -m axolotl.cli.main tui`
2. Select "Configuration Management" or press `C`
3. Press `Ctrl+N` for new configuration
4. Edit the template configuration
5. Press `Ctrl+V` to validate
6. Press `Ctrl+S` to save
### 2. Starting a Training Job
1. Navigate to "Training Management" or press `T`
2. Press `Ctrl+T` for new training job
3. Select configuration file and launcher
4. Monitor progress in real-time
5. View loss curves and logs
### 3. Interactive Model Testing
1. Go to "Inference & Testing" or press `I`
2. Load a trained model with `Ctrl+L`
3. Adjust inference parameters as needed
4. Start chatting with the model
5. Save conversation with `Ctrl+S`
## Navigation
### Global Shortcuts
- `Ctrl+Q`: Quit application
- `Escape`: Go back/close current screen
- `Tab`: Navigate between UI elements
- `Enter`: Select/activate element
- `Space`: Toggle switches/checkboxes
### Screen Shortcuts
Each screen has specific shortcuts displayed in the footer. Common patterns:
- `Ctrl+[Letter]`: Primary actions
- `R`: Refresh/reload
- `F1-F12`: Function keys for advanced features
## Customization
### Themes
The TUI uses Textual's theming system and can be customized by modifying the CSS in each screen class.
### Adding New Screens
1. Create a new screen class inheriting from `BaseScreen`
2. Implement the `compose()` method for UI layout
3. Add event handlers for user interactions
4. Register the screen in the main app navigation
### Extending Functionality
- Add new widgets to existing screens
- Implement custom data visualization
- Integrate with external tools and APIs
- Add new keyboard shortcuts
## Troubleshooting
### Common Issues
1. **Import Errors**: Ensure textual and rich are installed
2. **Permission Errors**: Check file system permissions for config directories
3. **GPU Monitoring**: Install pynvml for GPU monitoring features
4. **Config Validation**: Ensure axolotl dependencies are properly installed
### Debug Mode
Launch with debug logging:
```bash
TEXTUAL_LOG=DEBUG python -m axolotl.cli.main tui
```
### Performance
- Use `Ctrl+\` to open Textual's debug console
- Monitor memory usage with the system monitor
- Disable auto-refresh for better performance on slower systems
## Contributing
The TUI is designed to be extensible. Contributions are welcome for:
- New screen implementations
- Enhanced visualizations
- Better keyboard navigation
- Additional integrations
- Performance improvements
See the main Axolotl repository for contribution guidelines.

View File

@@ -1 +0,0 @@
"""Axolotl Terminal User Interface (TUI)."""

View File

@@ -1,180 +0,0 @@
"""Main TUI application for Axolotl."""
from textual import on
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import Screen
from textual.widgets import Button, Footer, Header, Static
from axolotl.tui.screens.config import ConfigScreen
from axolotl.tui.screens.datasets import DatasetScreen
from axolotl.tui.screens.inference import InferenceScreen
from axolotl.tui.screens.models import ModelScreen
from axolotl.tui.screens.monitor import MonitorScreen
from axolotl.tui.screens.training import TrainingScreen
class WelcomeScreen(Screen):
"""Welcome screen with main menu."""
BINDINGS = [
Binding("q", "quit", "Quit"),
Binding("c", "config", "Configuration"),
Binding("t", "training", "Training"),
Binding("d", "datasets", "Datasets"),
Binding("m", "models", "Models"),
Binding("i", "inference", "Inference"),
Binding("s", "monitor", "System Monitor"),
]
def compose(self) -> ComposeResult:
"""Compose the welcome screen."""
yield Header()
yield Container(
Static("🦾 Axolotl TUI", classes="title"),
Static(
"A Terminal User Interface for fine-tuning LLMs", classes="subtitle"
),
Container(
Button("Configuration Management [C]", id="config", variant="primary"),
Button("Training Management [T]", id="training", variant="primary"),
Button("Dataset Management [D]", id="datasets", variant="primary"),
Button("Model Management [M]", id="models", variant="primary"),
Button("Inference & Testing [I]", id="inference", variant="primary"),
Button("System Monitor [S]", id="monitor", variant="primary"),
classes="menu-container",
),
classes="welcome-container",
)
yield Footer()
def action_quit(self) -> None:
"""Quit the application."""
self.app.exit()
def action_config(self) -> None:
"""Navigate to config screen."""
self.app.push_screen(ConfigScreen())
def action_training(self) -> None:
"""Navigate to training screen."""
self.app.push_screen(TrainingScreen())
def action_datasets(self) -> None:
"""Navigate to datasets screen."""
self.app.push_screen(DatasetScreen())
def action_models(self) -> None:
"""Navigate to models screen."""
self.app.push_screen(ModelScreen())
def action_inference(self) -> None:
"""Navigate to inference screen."""
self.app.push_screen(InferenceScreen())
def action_monitor(self) -> None:
"""Navigate to monitor screen."""
self.app.push_screen(MonitorScreen())
@on(Button.Pressed, "#config")
def on_config_pressed(self) -> None:
"""Handle config button press."""
self.action_config()
@on(Button.Pressed, "#training")
def on_training_pressed(self) -> None:
"""Handle training button press."""
self.action_training()
@on(Button.Pressed, "#datasets")
def on_datasets_pressed(self) -> None:
"""Handle datasets button press."""
self.action_datasets()
@on(Button.Pressed, "#models")
def on_models_pressed(self) -> None:
"""Handle models button press."""
self.action_models()
@on(Button.Pressed, "#inference")
def on_inference_pressed(self) -> None:
"""Handle inference button press."""
self.action_inference()
@on(Button.Pressed, "#monitor")
def on_monitor_pressed(self) -> None:
"""Handle monitor button press."""
self.action_monitor()
class AxolotlTUI(App):
"""Main Axolotl TUI Application."""
CSS = """
.title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.subtitle {
text-align: center;
padding: 1;
color: $text-muted;
}
.welcome-container {
align: center middle;
height: 100%;
width: 100%;
}
.menu-container {
layout: vertical;
align: center middle;
padding: 2;
width: auto;
height: auto;
}
.menu-container Button {
width: 35;
margin: 1;
}
WelcomeScreen {
align: center middle;
}
"""
BINDINGS = [
Binding("ctrl+q", "quit", "Quit", priority=True),
Binding("escape", "back", "Back", priority=True),
]
def on_mount(self) -> None:
"""Called when the app is mounted."""
self.title = "Axolotl TUI"
self.sub_title = "Fine-tuning LLMs made easy"
self.push_screen(WelcomeScreen())
def action_quit(self) -> None:
"""Quit the application."""
self.exit()
def action_back(self) -> None:
"""Go back to previous screen."""
if len(self.screen_stack) > 1:
self.pop_screen()
def run():
"""Run the Axolotl TUI application."""
app = AxolotlTUI()
app.run()
if __name__ == "__main__":
run()

View File

@@ -1 +0,0 @@
"""TUI dialogs for Axolotl."""

View File

@@ -1,112 +0,0 @@
"""Training dialogs for Axolotl TUI."""
from pathlib import Path
from textual import on
from textual.app import ComposeResult
from textual.containers import Container
from textual.screen import ModalScreen
from textual.widgets import Button, Input, Label, Select, Static
class NewTrainingDialog(ModalScreen):
"""Dialog for starting a new training job."""
CSS = """
NewTrainingDialog {
align: center middle;
}
.dialog-container {
background: $surface;
border: thick $primary;
padding: 2;
width: 60;
height: auto;
}
.dialog-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.form-field {
margin: 1 0;
}
.form-label {
margin: 0 0 1 0;
color: $text-muted;
}
.button-container {
layout: horizontal;
align: center middle;
margin: 2 0 0 0;
}
.button-container Button {
margin: 0 1;
}
"""
def compose(self) -> ComposeResult:
"""Compose the dialog."""
yield Container(
Static("Start New Training Job", classes="dialog-title"),
Container(
Label("Configuration File:", classes="form-label"),
Input(
placeholder="Path to config YAML file",
id="config-path",
value="/workspace/configs/",
),
classes="form-field",
),
Container(
Label("Launcher:", classes="form-label"),
Select(
[
("accelerate", "Accelerate (Recommended)"),
("torchrun", "TorchRun"),
("deepspeed", "DeepSpeed"),
],
id="launcher",
value="accelerate",
),
classes="form-field",
),
Container(
Button("Start Training", variant="primary", id="start"),
Button("Cancel", variant="default", id="cancel"),
classes="button-container",
),
classes="dialog-container",
)
@on(Button.Pressed, "#start")
def handle_start(self) -> None:
"""Handle start button press."""
config_input = self.query_one("#config-path", Input)
launcher_select = self.query_one("#launcher", Select)
config_path = config_input.value.strip()
if not config_path:
return
if not Path(config_path).exists():
return
result = {
"config_path": config_path,
"launcher": launcher_select.value,
}
self.dismiss(result)
@on(Button.Pressed, "#cancel")
def handle_cancel(self) -> None:
"""Handle cancel button press."""
self.dismiss(None)

View File

@@ -1 +0,0 @@
"""TUI screens for Axolotl."""

View File

@@ -1,50 +0,0 @@
"""Base screen class for Axolotl TUI screens."""
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import Screen
from textual.widgets import Footer, Header, Static
class BaseScreen(Screen):
"""Base class for all Axolotl TUI screens."""
BINDINGS = [
Binding("escape", "back", "Back"),
Binding("q", "quit", "Quit"),
]
def __init__(self, title: str = "Axolotl", subtitle: str = ""):
"""Initialize the base screen.
Args:
title: The screen title
subtitle: Optional subtitle for the screen
"""
super().__init__()
self.screen_title = title
self.screen_subtitle = subtitle
def compose(self) -> ComposeResult:
"""Compose the base screen layout."""
yield Header()
yield Container(
Static(f"🦾 {self.screen_title}", classes="screen-title"),
(
Static(self.screen_subtitle, classes="screen-subtitle")
if self.screen_subtitle
else Static("")
),
Container(id="content"),
id="main-container",
)
yield Footer()
def action_back(self) -> None:
"""Go back to previous screen."""
self.app.pop_screen()
def action_quit(self) -> None:
"""Quit the application."""
self.app.exit()

View File

@@ -1,376 +0,0 @@
"""Configuration management screen for Axolotl TUI."""
import os
from pathlib import Path
from typing import Optional
import yaml
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.reactive import reactive
from textual.widgets import (
Button,
DirectoryTree,
Footer,
Header,
Label,
Log,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class ConfigScreen(BaseScreen):
"""Configuration management screen."""
BINDINGS = [
Binding("ctrl+n", "new_config", "New Config"),
Binding("ctrl+o", "open_config", "Open Config"),
Binding("ctrl+s", "save_config", "Save Config"),
Binding("ctrl+v", "validate_config", "Validate"),
Binding("ctrl+e", "edit_mode", "Toggle Edit Mode"),
]
CSS = """
.config-container {
layout: horizontal;
height: 100%;
}
.file-browser {
width: 30%;
border: solid $primary;
padding: 1;
margin: 1;
}
.config-editor {
width: 70%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.config-form {
height: 80%;
}
.config-actions {
layout: horizontal;
height: 3;
align: center middle;
padding: 1;
}
.config-actions Button {
margin: 0 1;
}
TextArea {
height: 100%;
}
.validation-log {
height: 20%;
border: solid $warning;
padding: 1;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
"""
def __init__(self):
"""Initialize the config screen."""
super().__init__(
title="Configuration Management",
subtitle="Create, edit, and validate Axolotl configurations",
)
self.current_config_path: Optional[Path] = None
self.edit_mode = reactive(False)
self.config_data = {}
def compose(self) -> ComposeResult:
"""Compose the config screen layout."""
yield Header()
yield Container(
Static("🦾 Configuration Management", classes="screen-title"),
Static(
"Create, edit, and validate Axolotl configurations",
classes="screen-subtitle",
),
Container(
Container(
Label("Config Files"),
DirectoryTree(
(
Path("/workspace/configs")
if Path("/workspace/configs").exists()
else Path.cwd()
),
id="config-tree",
),
classes="file-browser",
),
Container(
Container(
TextArea(
"",
language="yaml",
theme="monokai",
id="config-editor",
read_only=True,
),
classes="config-form",
),
Container(
Button("New", id="new-config", variant="primary"),
Button("Open", id="open-config", variant="primary"),
Button("Save", id="save-config", variant="success"),
Button("Validate", id="validate-config", variant="warning"),
Button("Edit Mode", id="toggle-edit", variant="default"),
Button("Load Example", id="load-example", variant="default"),
classes="config-actions",
),
Container(
Log(id="validation-log"),
classes="validation-log",
),
classes="config-editor",
),
classes="config-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
tree = self.query_one("#config-tree", DirectoryTree)
tree.show_root = False
tree.guide_depth = 3
log = self.query_one("#validation-log", Log)
log.write_line("Ready to load configuration files...")
@on(DirectoryTree.FileSelected)
def handle_file_selected(self, event: DirectoryTree.FileSelected) -> None:
"""Handle file selection from the directory tree."""
if event.path.suffix in [".yaml", ".yml"]:
self.load_config_file(event.path)
def load_config_file(self, path: Path) -> None:
"""Load a configuration file."""
self.current_config_path = path
try:
with open(path, "r") as f:
content = f.read()
self.config_data = yaml.safe_load(content)
editor = self.query_one("#config-editor", TextArea)
editor.load_text(content)
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line(f"✅ Loaded: {path.name}")
except Exception as e:
log = self.query_one("#validation-log", Log)
log.write_line(f"❌ Error loading {path.name}: {str(e)}")
@on(Button.Pressed, "#new-config")
def handle_new_config(self) -> None:
"""Create a new configuration."""
template = """# Axolotl Configuration
base_model:
model_type:
tokenizer_type:
# Dataset Configuration
datasets:
- path:
type:
# Training Configuration
output_dir: ./outputs
num_epochs: 3
micro_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 0.00002
warmup_steps: 100
eval_steps: 100
save_steps: 500
# LoRA Configuration (optional)
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
# Training optimizations
gradient_checkpointing: true
flash_attention: true
bf16: auto
tf32: true
# Logging
logging_steps: 10
wandb_project:
wandb_entity:
"""
editor = self.query_one("#config-editor", TextArea)
editor.load_text(template)
editor.read_only = False
self.edit_mode = True
self.update_edit_button()
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line("📝 New configuration created. Edit and save when ready.")
@on(Button.Pressed, "#save-config")
def handle_save_config(self) -> None:
"""Save the current configuration."""
editor = self.query_one("#config-editor", TextArea)
content = editor.text
if not content.strip():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ Cannot save empty configuration")
return
if not self.current_config_path:
default_path = Path("/workspace/configs/new_config.yaml")
default_path.parent.mkdir(parents=True, exist_ok=True)
self.current_config_path = default_path
try:
with open(self.current_config_path, "w") as f:
f.write(content)
log = self.query_one("#validation-log", Log)
log.write_line(f"💾 Saved: {self.current_config_path.name}")
except Exception as e:
log = self.query_one("#validation-log", Log)
log.write_line(f"❌ Error saving: {str(e)}")
@on(Button.Pressed, "#validate-config")
@work(thread=True)
async def handle_validate_config(self) -> None:
"""Validate the current configuration."""
editor = self.query_one("#config-editor", TextArea)
content = editor.text
if not content.strip():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ No configuration to validate")
return
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line("🔍 Validating configuration...")
try:
import tempfile
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
f.write(content)
temp_path = f.name
from argparse import Namespace
from axolotl.cli.config import check_user_config
args = Namespace(
config=temp_path,
debug=False,
debug_text_only=False,
debug_num_examples=5,
accelerate_config=None,
multi_gpu=False,
)
check_user_config(args)
log.write_line("✅ Configuration is valid!")
os.unlink(temp_path)
except Exception as e:
log.write_line(f"❌ Validation failed: {str(e)}")
if "temp_path" in locals():
os.unlink(temp_path)
@on(Button.Pressed, "#toggle-edit")
def handle_toggle_edit(self) -> None:
"""Toggle edit mode for the configuration."""
editor = self.query_one("#config-editor", TextArea)
self.edit_mode = not self.edit_mode
editor.read_only = not self.edit_mode
self.update_edit_button()
log = self.query_one("#validation-log", Log)
if self.edit_mode:
log.write_line("✏️ Edit mode enabled")
else:
log.write_line("👁️ View mode enabled")
@on(Button.Pressed, "#load-example")
async def handle_load_example(self) -> None:
"""Load an example configuration."""
examples_dir = Path("/workspace/axolotl/examples")
if not examples_dir.exists():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ Examples directory not found")
return
yaml_files = list(examples_dir.glob("**/*.yml")) + list(
examples_dir.glob("**/*.yaml")
)
if yaml_files:
self.load_config_file(yaml_files[0])
log = self.query_one("#validation-log", Log)
log.write_line(f"📚 Loaded example: {yaml_files[0].name}")
def update_edit_button(self) -> None:
"""Update the edit button appearance."""
button = self.query_one("#toggle-edit", Button)
if self.edit_mode:
button.variant = "warning"
button.label = "Edit Mode: ON"
else:
button.variant = "default"
button.label = "Edit Mode: OFF"
def action_new_config(self) -> None:
"""Create a new configuration."""
self.handle_new_config()
def action_save_config(self) -> None:
"""Save the current configuration."""
self.handle_save_config()
def action_validate_config(self) -> None:
"""Validate the current configuration."""
self.handle_validate_config()
def action_edit_mode(self) -> None:
"""Toggle edit mode."""
self.handle_toggle_edit()

View File

@@ -1,440 +0,0 @@
"""Dataset management screen for Axolotl TUI."""
import json
from pathlib import Path
from typing import Dict, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class DatasetScreen(BaseScreen):
"""Dataset management screen."""
BINDINGS = [
Binding("ctrl+p", "preprocess", "Preprocess"),
Binding("ctrl+v", "preview", "Preview"),
Binding("ctrl+i", "info", "Info"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.dataset-container {
layout: horizontal;
height: 100%;
}
.dataset-list {
width: 40%;
border: solid $primary;
padding: 1;
margin: 1;
}
.dataset-details {
width: 60%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.dataset-actions {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.dataset-actions Button {
margin: 0 1;
}
DataTable {
height: 100%;
}
.preview-container {
height: 100%;
border: solid $primary;
padding: 1;
}
TextArea {
height: 100%;
}
.stats-container {
layout: vertical;
padding: 1;
}
.stat-row {
layout: horizontal;
padding: 0 0 1 0;
}
.stat-label {
width: 50%;
color: $text-muted;
}
.stat-value {
width: 50%;
text-align: right;
text-style: bold;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
.progress-container {
padding: 1;
border: solid $warning;
margin: 1;
}
"""
def __init__(self):
"""Initialize the dataset screen."""
super().__init__(
title="Dataset Management",
subtitle="Browse, preview, and preprocess datasets",
)
self.datasets: Dict[str, Dict] = {}
self.selected_dataset: Optional[str] = None
self.preprocessing_active = False
def compose(self) -> ComposeResult:
"""Compose the dataset screen layout."""
yield Header()
yield Container(
Static("🦾 Dataset Management", classes="screen-title"),
Static(
"Browse, preview, and preprocess datasets", classes="screen-subtitle"
),
Container(
Container(
Label("Available Datasets"),
DataTable(id="dataset-table"),
Container(
Button("Load Dataset", id="load-dataset", variant="primary"),
Button("Preprocess", id="preprocess", variant="success"),
Button("Download", id="download", variant="default"),
Button("Refresh", id="refresh", variant="default"),
classes="dataset-actions",
),
classes="dataset-list",
),
Container(
TextArea("", id="dataset-preview", read_only=True),
Container(
Static("Dataset Name:", classes="stat-label"),
Static("-", id="stat-name", classes="stat-value"),
Static("Type:", classes="stat-label"),
Static("-", id="stat-type", classes="stat-value"),
Static("Size:", classes="stat-label"),
Static("-", id="stat-size", classes="stat-value"),
Static("Samples:", classes="stat-label"),
Static("-", id="stat-samples", classes="stat-value"),
Static("Features:", classes="stat-label"),
Static("-", id="stat-features", classes="stat-value"),
Static("Format:", classes="stat-label"),
Static("-", id="stat-format", classes="stat-value"),
Static("Preprocessed:", classes="stat-label"),
Static("-", id="stat-preprocessed", classes="stat-value"),
),
Log(id="processing-log"),
ProgressBar(total=100, id="preprocessing-progress"),
classes="dataset-details",
),
classes="dataset-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_dataset_table()
self.load_datasets()
log = self.query_one("#processing-log", Log)
log.write_line("Dataset manager ready.")
def setup_dataset_table(self) -> None:
"""Setup the dataset table."""
table = self.query_one("#dataset-table", DataTable)
table.add_columns("Name", "Type", "Size", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
@work(thread=True)
async def load_datasets(self) -> None:
"""Load available datasets."""
# Check for local datasets
datasets_dir = Path("/workspace/datasets")
if datasets_dir.exists():
for dataset_path in datasets_dir.glob("*"):
if dataset_path.is_dir():
self.datasets[dataset_path.name] = {
"name": dataset_path.name,
"path": str(dataset_path),
"type": "local",
"size": self.get_dir_size(dataset_path),
"status": "available",
}
# Check for HuggingFace datasets in configs
configs_dir = Path("/workspace/configs")
if configs_dir.exists():
for config_file in configs_dir.glob("*.yaml"):
try:
import yaml
with open(config_file) as f:
config = yaml.safe_load(f)
if "datasets" in config:
for ds in config.get("datasets", []):
if "path" in ds:
ds_name = ds["path"].split("/")[-1]
self.datasets[ds_name] = {
"name": ds_name,
"path": ds["path"],
"type": ds.get("type", "huggingface"),
"size": "Unknown",
"status": "remote",
}
except Exception:
pass
self.refresh_dataset_table()
def get_dir_size(self, path: Path) -> str:
"""Get human-readable directory size."""
total_size = sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
for unit in ["B", "KB", "MB", "GB"]:
if total_size < 1024.0:
return f"{total_size:.2f} {unit}"
total_size /= 1024.0
return f"{total_size:.2f} TB"
def refresh_dataset_table(self) -> None:
"""Refresh the dataset table."""
table = self.query_one("#dataset-table", DataTable)
table.clear()
for name, info in self.datasets.items():
table.add_row(
name[:30],
info["type"],
info["size"],
info["status"],
)
@on(DataTable.RowSelected)
def handle_dataset_selected(self, event: DataTable.RowSelected) -> None:
"""Handle dataset selection from table."""
if event.cursor_row >= 0:
dataset_names = list(self.datasets.keys())
if event.cursor_row < len(dataset_names):
self.selected_dataset = dataset_names[event.cursor_row]
self.load_dataset_preview()
self.update_dataset_stats()
@work(thread=True)
async def load_dataset_preview(self) -> None:
"""Load preview of selected dataset."""
if not self.selected_dataset:
return
dataset_info = self.datasets[self.selected_dataset]
preview_text = ""
try:
if dataset_info["type"] == "local" and Path(dataset_info["path"]).exists():
# Load first few samples from local dataset
sample_files = list(Path(dataset_info["path"]).glob("*.json"))[:3]
samples = []
for sample_file in sample_files:
with open(sample_file) as f:
samples.append(json.load(f))
preview_text = json.dumps(samples, indent=2)
else:
# Show dataset info for remote datasets
preview_text = json.dumps(dataset_info, indent=2)
except Exception as e:
preview_text = f"Error loading preview: {str(e)}"
preview = self.query_one("#dataset-preview", TextArea)
preview.load_text(preview_text)
def update_dataset_stats(self) -> None:
"""Update dataset statistics display."""
if not self.selected_dataset:
return
info = self.datasets[self.selected_dataset]
self.query_one("#stat-name", Static).update(info["name"])
self.query_one("#stat-type", Static).update(info["type"])
self.query_one("#stat-size", Static).update(info["size"])
self.query_one("#stat-samples", Static).update("N/A")
self.query_one("#stat-features", Static).update("N/A")
self.query_one("#stat-format", Static).update("JSON")
self.query_one("#stat-preprocessed", Static).update("No")
@on(Button.Pressed, "#preprocess")
@work(thread=True)
async def handle_preprocess(self) -> None:
"""Preprocess selected dataset."""
if not self.selected_dataset or self.preprocessing_active:
return
self.preprocessing_active = True
dataset_info = self.datasets[self.selected_dataset]
log = self.query_one("#processing-log", Log)
log.clear()
log.write_line(f"🔄 Starting preprocessing for {self.selected_dataset}...")
progress = self.query_one("#preprocessing-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
import tempfile
# Create a temporary config for preprocessing
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
config = {
"datasets": [
{
"path": dataset_info["path"],
"type": dataset_info.get("type", "alpaca"),
}
],
"output_dir": f"/tmp/preprocessed_{self.selected_dataset}",
}
import yaml
yaml.dump(config, f)
temp_config = f.name
# Run preprocessing
cmd = ["python", "-m", "axolotl.cli.preprocess", temp_config]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
# Monitor progress
for line in process.stdout:
log.write_line(line.strip())
# Update progress bar based on output
if "Processing" in line:
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ Preprocessing completed successfully!")
dataset_info["status"] = "preprocessed"
progress.update(progress=100)
else:
log.write_line(
f"❌ Preprocessing failed with code {process.returncode}"
)
import os
os.unlink(temp_config)
except Exception as e:
log.write_line(f"❌ Error during preprocessing: {str(e)}")
finally:
self.preprocessing_active = False
self.refresh_dataset_table()
@on(Button.Pressed, "#load-dataset")
async def handle_load_dataset(self) -> None:
"""Load a new dataset."""
log = self.query_one("#processing-log", Log)
log.write_line("📦 Load dataset functionality coming soon...")
@on(Button.Pressed, "#download")
@work(thread=True)
async def handle_download(self) -> None:
"""Download a remote dataset."""
if not self.selected_dataset:
return
dataset_info = self.datasets[self.selected_dataset]
if dataset_info["type"] != "huggingface":
return
log = self.query_one("#processing-log", Log)
log.clear()
log.write_line(f"📥 Downloading {self.selected_dataset} from HuggingFace...")
try:
from datasets import load_dataset
dataset = load_dataset(dataset_info["path"])
save_path = Path(f"/workspace/datasets/{self.selected_dataset}")
save_path.mkdir(parents=True, exist_ok=True)
dataset.save_to_disk(str(save_path))
log.write_line(f"✅ Downloaded to {save_path}")
dataset_info["type"] = "local"
dataset_info["status"] = "available"
dataset_info["path"] = str(save_path)
self.refresh_dataset_table()
except Exception as e:
log.write_line(f"❌ Download failed: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh dataset list."""
self.load_datasets()
def action_preprocess(self) -> None:
"""Preprocess selected dataset."""
self.handle_preprocess()
def action_refresh(self) -> None:
"""Refresh dataset list."""
self.handle_refresh()

View File

@@ -1,445 +0,0 @@
"""Inference and testing screen for Axolotl TUI."""
from pathlib import Path
from typing import Dict, List, Optional
from textual import events, on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
Input,
Label,
Log,
Select,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class InferenceScreen(BaseScreen):
"""Inference and testing screen."""
BINDINGS = [
Binding("ctrl+enter", "send_message", "Send"),
Binding("ctrl+c", "clear_chat", "Clear"),
Binding("ctrl+l", "load_model", "Load Model"),
Binding("ctrl+s", "save_chat", "Save Chat"),
]
CSS = """
.inference-container {
layout: horizontal;
height: 100%;
}
.model-selector {
width: 30%;
border: solid $primary;
padding: 1;
margin: 1;
}
.chat-interface {
width: 70%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.chat-history {
height: 70%;
border: solid $primary;
padding: 1;
margin: 0 0 1 0;
}
.input-area {
height: 20%;
border: solid $warning;
padding: 1;
margin: 0 0 1 0;
}
.chat-controls {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.chat-controls Button {
margin: 0 1;
}
.model-info {
padding: 1;
border: solid $surface;
margin: 1 0;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
TextArea {
height: 100%;
}
Log {
height: 100%;
}
"""
def __init__(self):
"""Initialize the inference screen."""
super().__init__(
title="Inference & Testing", subtitle="Interactive chat and model testing"
)
self.loaded_model: Optional[str] = None
self.chat_history: List[Dict[str, str]] = []
def compose(self) -> ComposeResult:
"""Compose the inference screen layout."""
yield Container(
Static("🦾 Inference & Testing", classes="screen-title"),
Static("Interactive chat and model testing", classes="screen-subtitle"),
Container(
Container(
Label("Model Selection"),
Select(
[("No model loaded", "none")],
id="model-select",
value="none",
),
Container(
Button("Load Model", id="load-model", variant="primary"),
Button("Unload", id="unload-model", variant="default"),
Button("Gradio UI", id="gradio-ui", variant="success"),
),
Container(
Static("No model loaded", id="model-status"),
classes="model-info",
),
Label("Inference Parameters"),
Container(
Label("Temperature:"),
Input(value="0.7", id="temperature"),
Label("Max Tokens:"),
Input(value="256", id="max-tokens"),
Label("Top P:"),
Input(value="0.9", id="top-p"),
),
classes="model-selector",
),
Container(
Container(
Log(id="chat-history"),
classes="chat-history",
),
Container(
TextArea(
id="message-input",
),
classes="input-area",
),
Container(
Button("Send [Ctrl+Enter]", id="send", variant="primary"),
Button("Clear Chat", id="clear", variant="warning"),
Button("Save Chat", id="save-chat", variant="default"),
Button("Load Examples", id="load-examples", variant="default"),
classes="chat-controls",
),
classes="chat-interface",
),
classes="inference-container",
),
id="content",
)
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.load_available_models()
chat = self.query_one("#chat-history", Log)
chat.write_line("💬 Welcome to Axolotl Inference!")
chat.write_line("Load a model to start chatting.")
@work(thread=True)
async def load_available_models(self) -> None:
"""Load list of available models."""
models = [("No model loaded", "none")]
chat = self.query_one("#chat-history", Log)
chat.write_line("🔍 Scanning for available models...")
# Check for trained models
outputs_dir = Path("./outputs")
chat.write_line(f"Checking outputs directory: {outputs_dir.absolute()}")
if outputs_dir.exists():
found_models = 0
for model_dir in outputs_dir.glob("*"):
if model_dir.is_dir():
# Look for various model file types
model_files = (
list(model_dir.glob("pytorch_model.bin"))
+ list(model_dir.glob("model.safetensors"))
+ list(model_dir.glob("*.bin"))
+ list(model_dir.glob("*.safetensors"))
)
if model_files:
models.append((model_dir.name, str(model_dir)))
found_models += 1
chat.write_line(f"Found {found_models} trained models in outputs/")
else:
chat.write_line("outputs/ directory not found")
# Add some example/demo models for testing
models.extend(
[
("Demo: GPT-2 Small", "gpt2"),
("Demo: TinyLlama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
("Demo: Phi-2", "microsoft/phi-2"),
]
)
select = self.query_one("#model-select", Select)
select.set_options(models)
chat.write_line(f"✅ Loaded {len(models)} models in dropdown")
@on(Button.Pressed, "#load-model")
@work(thread=True)
async def handle_load_model(self) -> None:
"""Load selected model for inference."""
select = self.query_one("#model-select", Select)
if select.value == "none":
return
chat = self.query_one("#chat-history", Log)
chat.write_line(f"🔄 Loading model: {select.value}")
status = self.query_one("#model-status", Static)
status.update("Loading...")
try:
# Simulate model loading (in real implementation, would load the actual model)
import time
time.sleep(2) # Simulate loading time
self.loaded_model = select.value
status.update(f"✅ Loaded: {Path(select.value).name}")
chat.write_line("✅ Model loaded successfully!")
chat.write_line("You can now start chatting.")
except Exception as e:
status.update("❌ Failed to load")
chat.write_line(f"❌ Failed to load model: {str(e)}")
@on(Button.Pressed, "#send")
async def handle_send_message(self) -> None:
"""Send message to model."""
if not self.loaded_model:
chat = self.query_one("#chat-history", Log)
chat.write_line("⚠️ Please load a model first")
return
message_input = self.query_one("#message-input", TextArea)
message = message_input.text.strip()
if not message:
return
# Add user message to chat
chat = self.query_one("#chat-history", Log)
chat.write_line(f"👤 User: {message}")
# Clear input
message_input.clear()
# Add to history
self.chat_history.append({"role": "user", "content": message})
# Generate response (placeholder)
self.generate_response(message)
@on(TextArea.Changed, "#message-input")
def on_message_input_changed(self, event: TextArea.Changed) -> None:
"""Handle changes to the message input."""
# This could be used for features like typing indicators
pass
def on_key(self, event: events.Key) -> None:
"""Handle key events globally."""
# Check if we're focused on the message input and Ctrl+Enter is pressed
focused = self.focused
if focused and focused.id == "message-input" and event.key == "ctrl+enter":
event.prevent_default()
self.handle_send_message()
@work(thread=True)
async def generate_response(self, message: str) -> None:
"""Generate model response."""
chat = self.query_one("#chat-history", Log)
chat.write_line("🤖 Assistant: Thinking...")
try:
# Get inference parameters
float(self.query_one("#temperature", Input).value)
int(self.query_one("#max-tokens", Input).value)
float(self.query_one("#top-p", Input).value)
if not self.loaded_model or self.loaded_model == "none":
response = "I don't have a model loaded yet. Please load a model first using the 'Load Model' button."
elif self.loaded_model.startswith("gpt2"):
# Simple response for GPT-2
responses = [
f"Thanks for your message: '{message}'. I'm a GPT-2 model running in demo mode.",
"I understand you're testing the interface. GPT-2 models are great for experimentation!",
"This is a simulated GPT-2 response. In a real setup, I'd generate text based on your input.",
f"GPT-2 here! You said: '{message}'. I'd normally continue this conversation creatively.",
]
import random
response = random.choice(responses)
elif "llama" in self.loaded_model.lower():
# Response for Llama models
response = f"🦙 LLaMA model here! You asked: '{message}'. I'm designed for helpful, harmless, and honest conversations. How can I assist you today?"
elif "phi" in self.loaded_model.lower():
# Response for Phi models
response = f"Phi model responding! Your message: '{message}'. I'm optimized for reasoning and code tasks. What would you like to explore?"
else:
# Generic response for other models
response = f"Model '{self.loaded_model}' responding to: '{message}'. I'm ready to help with your questions!"
# Simulate inference time
import time
time.sleep(0.5)
# Clear the "thinking" message and show response
chat.write_line(f"🤖 Assistant: {response}")
# Add to history
self.chat_history.append({"role": "assistant", "content": response})
except Exception as e:
chat.write_line(f"❌ Error generating response: {str(e)}")
@on(Button.Pressed, "#clear")
def handle_clear_chat(self) -> None:
"""Clear chat history."""
chat = self.query_one("#chat-history", Log)
chat.clear()
self.chat_history = []
chat.write_line("💬 Chat cleared. Start a new conversation!")
@on(Button.Pressed, "#save-chat")
def handle_save_chat(self) -> None:
"""Save chat history to file."""
if not self.chat_history:
chat = self.query_one("#chat-history", Log)
chat.write_line("⚠️ No chat history to save")
return
try:
import json
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"chat_history_{timestamp}.json"
with open(filename, "w") as f:
json.dump(self.chat_history, f, indent=2)
chat = self.query_one("#chat-history", Log)
chat.write_line(f"💾 Chat saved to {filename}")
except Exception as e:
chat = self.query_one("#chat-history", Log)
chat.write_line(f"❌ Error saving chat: {str(e)}")
@on(Button.Pressed, "#load-examples")
def handle_load_examples(self) -> None:
"""Load example prompts."""
examples = [
"Explain the concept of machine learning in simple terms.",
"Write a Python function to calculate fibonacci numbers.",
"What are the benefits of fine-tuning language models?",
"Describe the difference between supervised and unsupervised learning.",
]
chat = self.query_one("#chat-history", Log)
chat.write_line("📚 Example prompts:")
for i, example in enumerate(examples, 1):
chat.write_line(f"{i}. {example}")
chat.write_line("Copy and paste any example to try it out!")
@on(Button.Pressed, "#gradio-ui")
@work(thread=True)
async def handle_gradio_ui(self) -> None:
"""Launch Gradio web interface."""
chat = self.query_one("#chat-history", Log)
chat.write_line("🌐 Launching Gradio web interface...")
try:
import subprocess
if self.loaded_model:
cmd = [
"python",
"-m",
"axolotl.cli.inference",
self.loaded_model,
"--gradio",
]
else:
chat.write_line("⚠️ No model loaded. Loading default interface...")
cmd = ["python", "-m", "axolotl.cli.inference", "--gradio"]
subprocess.Popen(cmd)
chat.write_line("✅ Gradio interface launched! Check your browser.")
except Exception as e:
chat.write_line(f"❌ Error launching Gradio: {str(e)}")
@on(Button.Pressed, "#unload-model")
def handle_unload_model(self) -> None:
"""Unload current model."""
self.loaded_model = None
status = self.query_one("#model-status", Static)
status.update("No model loaded")
select = self.query_one("#model-select", Select)
select.value = "none"
chat = self.query_one("#chat-history", Log)
chat.write_line("🔄 Model unloaded")
def action_send_message(self) -> None:
"""Send message action."""
self.handle_send_message()
def action_clear_chat(self) -> None:
"""Clear chat action."""
self.handle_clear_chat()
def action_load_model(self) -> None:
"""Load model action."""
self.handle_load_model()
def action_save_chat(self) -> None:
"""Save chat action."""
self.handle_save_chat()

View File

@@ -1,373 +0,0 @@
"""Model management screen for Axolotl TUI."""
from pathlib import Path
from typing import Dict, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container, ScrollableContainer
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Static,
TabbedContent,
TabPane,
)
from axolotl.tui.screens.base import BaseScreen
class ModelScreen(BaseScreen):
"""Model management screen."""
BINDINGS = [
Binding("ctrl+m", "merge_lora", "Merge LoRA"),
Binding("ctrl+q", "quantize", "Quantize"),
Binding("ctrl+e", "evaluate", "Evaluate"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.model-container {
layout: horizontal;
height: 100%;
}
.model-list {
width: 50%;
border: solid $primary;
padding: 1;
margin: 1;
}
.model-operations {
width: 50%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.model-actions {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.model-actions Button {
margin: 0 1;
}
DataTable {
height: 80%;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
"""
def __init__(self):
"""Initialize the model screen."""
super().__init__(
title="Model Management",
subtitle="Manage trained models, merge LoRA adapters, and quantize models",
)
self.models: Dict[str, Dict] = {}
self.selected_model: Optional[str] = None
def compose(self) -> ComposeResult:
"""Compose the model screen layout."""
yield Header()
with Container(id="content"):
yield Static("🦾 Model Management", classes="screen-title")
yield Static(
"Manage trained models, merge LoRA adapters, and quantize models",
classes="screen-subtitle",
)
with Container(classes="model-container"):
with Container(classes="model-list"):
yield Label("Available Models")
yield DataTable(id="model-table")
with Container(classes="model-actions"):
yield Button("Merge LoRA", id="merge-lora", variant="primary")
yield Button("Quantize", id="quantize", variant="success")
yield Button("Evaluate", id="evaluate", variant="warning")
yield Button("Refresh", id="refresh", variant="default")
with Container(classes="model-operations"):
with TabbedContent():
with TabPane("Operations"):
with Container():
yield Log(id="operations-log")
with Container():
yield Label("Operation Progress:")
yield ProgressBar(
total=100,
id="operation-progress",
)
with TabPane("Model Info"):
with ScrollableContainer():
yield Static(
"Model information will appear here",
id="model-info",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_model_table()
self.load_models()
log = self.query_one("#operations-log", Log)
log.write_line("Model manager ready.")
def setup_model_table(self) -> None:
"""Setup the model table."""
table = self.query_one("#model-table", DataTable)
table.add_columns("Name", "Type", "Size", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
@work(thread=True)
async def load_models(self) -> None:
"""Load available models."""
# Check outputs directory for trained models
outputs_dir = Path("./outputs")
if outputs_dir.exists():
for model_dir in outputs_dir.glob("*"):
if model_dir.is_dir():
self.models[model_dir.name] = {
"name": model_dir.name,
"path": str(model_dir),
"type": "checkpoint",
"size": self.get_dir_size(model_dir),
"status": "available",
}
self.refresh_model_table()
def get_dir_size(self, path: Path) -> str:
"""Get human-readable directory size."""
try:
total_size = sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
for unit in ["B", "KB", "MB", "GB"]:
if total_size < 1024.0:
return f"{total_size:.2f} {unit}"
total_size /= 1024.0
return f"{total_size:.2f} TB"
except Exception:
return "Unknown"
def refresh_model_table(self) -> None:
"""Refresh the model table."""
table = self.query_one("#model-table", DataTable)
table.clear()
for name, info in self.models.items():
table.add_row(
name[:30],
info["type"],
info["size"],
info["status"],
)
@on(DataTable.RowSelected)
def handle_model_selected(self, event: DataTable.RowSelected) -> None:
"""Handle model selection from table."""
if event.cursor_row >= 0:
model_names = list(self.models.keys())
if event.cursor_row < len(model_names):
self.selected_model = model_names[event.cursor_row]
self.update_model_info()
def update_model_info(self) -> None:
"""Update model information display."""
if not self.selected_model:
return
info = self.models[self.selected_model]
info_text = f"""
Model Name: {info['name']}
Path: {info['path']}
Type: {info['type']}
Size: {info['size']}
Status: {info['status']}
"""
self.query_one("#model-info", Static).update(info_text)
@on(Button.Pressed, "#merge-lora")
@work(thread=True)
async def handle_merge_lora(self) -> None:
"""Merge LoRA adapters with base model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Merging LoRA adapters for {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = ["python", "-m", "axolotl.cli.merge_lora", model_info["path"]]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ LoRA merge completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ LoRA merge failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during LoRA merge: {str(e)}")
@on(Button.Pressed, "#quantize")
@work(thread=True)
async def handle_quantize(self) -> None:
"""Quantize selected model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Quantizing {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = [
"python",
"-m",
"axolotl.cli.quantize",
model_info["path"],
"--output-dir",
f"{model_info['path']}_quantized",
]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(5)
process.wait()
if process.returncode == 0:
log.write_line("✅ Quantization completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ Quantization failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during quantization: {str(e)}")
@on(Button.Pressed, "#evaluate")
@work(thread=True)
async def handle_evaluate(self) -> None:
"""Evaluate selected model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Evaluating {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = ["python", "-m", "axolotl.cli.evaluate", model_info["path"]]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ Evaluation completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ Evaluation failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during evaluation: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh model list."""
self.load_models()
def action_merge_lora(self) -> None:
"""Merge LoRA adapters."""
self.handle_merge_lora()
def action_quantize(self) -> None:
"""Quantize model."""
self.handle_quantize()
def action_evaluate(self) -> None:
"""Evaluate model."""
self.handle_evaluate()
def action_refresh(self) -> None:
"""Refresh model list."""
self.handle_refresh()

View File

@@ -1,414 +0,0 @@
"""System monitoring screen for Axolotl TUI."""
import psutil
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Sparkline,
Static,
)
from axolotl.tui.screens.base import BaseScreen
class MonitorScreen(BaseScreen):
"""System monitoring screen."""
BINDINGS = [
Binding("r", "refresh", "Refresh"),
Binding("ctrl+k", "kill_process", "Kill Process"),
]
CSS = """
.monitor-container {
layout: vertical;
height: 100%;
}
.metrics-grid {
layout: horizontal;
height: 20%;
padding: 1;
}
.metric-card {
width: 25%;
border: solid $surface;
padding: 1;
margin: 0 1;
}
.metric-label {
text-style: bold;
color: $text-muted;
text-align: center;
}
.metric-value {
text-style: bold;
text-align: center;
padding: 1;
}
.charts-container {
height: 40%;
layout: horizontal;
padding: 1;
}
.chart-panel {
width: 50%;
border: solid $primary;
padding: 1;
margin: 0 1;
}
.processes-container {
height: 40%;
border: solid $warning;
padding: 1;
margin: 1;
}
DataTable {
height: 90%;
}
.process-controls {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.process-controls Button {
margin: 0 1;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
Sparkline {
height: 8;
}
ProgressBar {
margin: 1 0;
}
"""
def __init__(self):
"""Initialize the monitor screen."""
super().__init__(
title="System Monitor",
subtitle="Monitor system resources and running processes",
)
self.cpu_history = []
self.memory_history = []
self.gpu_history = []
def compose(self) -> ComposeResult:
"""Compose the monitor screen layout."""
yield Header()
yield Container(
Static("🦾 System Monitor", classes="screen-title"),
Static(
"Monitor system resources and running processes",
classes="screen-subtitle",
),
Container(
Container(
Container(
Static("CPU Usage", classes="metric-label"),
Static("0%", id="cpu-usage", classes="metric-value"),
ProgressBar(total=100, id="cpu-progress"),
classes="metric-card",
),
Container(
Static("Memory", classes="metric-label"),
Static("0%", id="memory-usage", classes="metric-value"),
ProgressBar(total=100, id="memory-progress"),
classes="metric-card",
),
Container(
Static("GPU Usage", classes="metric-label"),
Static("0%", id="gpu-usage", classes="metric-value"),
ProgressBar(total=100, id="gpu-progress"),
classes="metric-card",
),
Container(
Static("Temperature", classes="metric-label"),
Static("0°C", id="temperature", classes="metric-value"),
classes="metric-card",
),
classes="metrics-grid",
),
Container(
Container(
Label("CPU History"),
Sparkline([], id="cpu-sparkline"),
classes="chart-panel",
),
Container(
Label("Memory History"),
Sparkline([], id="memory-sparkline"),
classes="chart-panel",
),
classes="charts-container",
),
Container(
DataTable(id="process-table"),
Log(id="gpu-info"),
Log(id="system-logs"),
classes="processes-container",
),
classes="monitor-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_process_table()
self.start_monitoring()
# Initial system info
self.update_system_info()
self.update_gpu_info()
def setup_process_table(self) -> None:
"""Setup the process table."""
table = self.query_one("#process-table", DataTable)
table.add_columns("PID", "Name", "CPU%", "Memory%", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
def start_monitoring(self) -> None:
"""Start the monitoring timer."""
self.set_interval(2.0, self.update_system_metrics)
@work(thread=True)
async def update_system_metrics(self) -> None:
"""Update system metrics."""
try:
# CPU usage
cpu_percent = psutil.cpu_percent(interval=None)
self.cpu_history.append(cpu_percent)
if len(self.cpu_history) > 50:
self.cpu_history.pop(0)
# Memory usage
memory = psutil.virtual_memory()
memory_percent = memory.percent
self.memory_history.append(memory_percent)
if len(self.memory_history) > 50:
self.memory_history.pop(0)
# GPU usage (if available)
gpu_percent = self.get_gpu_usage()
self.gpu_history.append(gpu_percent)
if len(self.gpu_history) > 50:
self.gpu_history.pop(0)
# Temperature
temperature = self.get_temperature()
# Update UI
self.update_metrics_display(
cpu_percent, memory_percent, gpu_percent, temperature
)
self.update_sparklines()
self.update_process_table()
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error updating metrics: {str(e)}")
def get_gpu_usage(self) -> float:
"""Get GPU usage percentage."""
try:
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
return util.gpu
except Exception:
return 0.0
def get_temperature(self) -> str:
"""Get system temperature."""
try:
temps = psutil.sensors_temperatures()
if temps:
for name, entries in temps.items():
if entries:
return f"{entries[0].current:.1f}°C"
return "N/A"
except Exception:
return "N/A"
def update_metrics_display(
self, cpu: float, memory: float, gpu: float, temp: str
) -> None:
"""Update metrics display."""
self.query_one("#cpu-usage", Static).update(f"{cpu:.1f}%")
self.query_one("#memory-usage", Static).update(f"{memory:.1f}%")
self.query_one("#gpu-usage", Static).update(f"{gpu:.1f}%")
self.query_one("#temperature", Static).update(temp)
self.query_one("#cpu-progress", ProgressBar).update(progress=cpu)
self.query_one("#memory-progress", ProgressBar).update(progress=memory)
self.query_one("#gpu-progress", ProgressBar).update(progress=gpu)
def update_sparklines(self) -> None:
"""Update sparkline charts."""
if self.cpu_history:
cpu_sparkline = self.query_one("#cpu-sparkline", Sparkline)
cpu_sparkline.data = self.cpu_history
if self.memory_history:
memory_sparkline = self.query_one("#memory-sparkline", Sparkline)
memory_sparkline.data = self.memory_history
def update_process_table(self) -> None:
"""Update the process table."""
table = self.query_one("#process-table", DataTable)
table.clear()
try:
# Get top processes by CPU usage
processes = []
for proc in psutil.process_iter(
["pid", "name", "cpu_percent", "memory_percent", "status"]
):
try:
pinfo = proc.info
if pinfo["cpu_percent"] > 0.1: # Only show processes using CPU
processes.append(pinfo)
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
# Sort by CPU usage
processes.sort(key=lambda x: x["cpu_percent"], reverse=True)
# Add top 20 processes
for proc in processes[:20]:
table.add_row(
str(proc["pid"]),
proc["name"][:20],
f"{proc['cpu_percent']:.1f}%",
f"{proc['memory_percent']:.1f}%",
proc["status"],
)
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error updating process table: {str(e)}")
def update_system_info(self) -> None:
"""Update system information."""
try:
# System info
psutil.boot_time()
cpu_count = psutil.cpu_count()
memory = psutil.virtual_memory()
log = self.query_one("#system-logs", Log)
log.write_line(f"System started. CPU cores: {cpu_count}")
log.write_line(f"Total memory: {memory.total / (1024**3):.1f} GB")
log.write_line(f"Available memory: {memory.available / (1024**3):.1f} GB")
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error getting system info: {str(e)}")
def update_gpu_info(self) -> None:
"""Update GPU information."""
try:
import pynvml
pynvml.nvmlInit()
device_count = pynvml.nvmlDeviceGetCount()
log = self.query_one("#gpu-info", Log)
log.clear()
log.write_line(f"Found {device_count} GPU(s)")
for i in range(device_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
name = pynvml.nvmlDeviceGetName(handle).decode()
memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
log.write_line(f"\nGPU {i}: {name}")
log.write_line(
f"Memory: {memory_info.used / (1024**3):.1f} / {memory_info.total / (1024**3):.1f} GB"
)
log.write_line(f"Free: {memory_info.free / (1024**3):.1f} GB")
except Exception as e:
log = self.query_one("#gpu-info", Log)
log.clear()
log.write_line(f"GPU info unavailable: {str(e)}")
@on(Button.Pressed, "#kill-process")
def handle_kill_process(self) -> None:
"""Kill selected process."""
table = self.query_one("#process-table", DataTable)
if table.cursor_row >= 0:
try:
row = table.get_row_at(table.cursor_row)
pid = int(row[0])
process = psutil.Process(pid)
process.terminate()
log = self.query_one("#system-logs", Log)
log.write_line(f"Terminated process {pid}")
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error killing process: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh all metrics."""
self.update_system_info()
self.update_gpu_info()
log = self.query_one("#system-logs", Log)
log.write_line("Metrics refreshed")
@on(Button.Pressed, "#auto-refresh")
def handle_auto_refresh(self) -> None:
"""Toggle auto refresh."""
log = self.query_one("#system-logs", Log)
log.write_line("Auto refresh is always enabled (every 2 seconds)")
def action_refresh(self) -> None:
"""Refresh action."""
self.handle_refresh()
def action_kill_process(self) -> None:
"""Kill process action."""
self.handle_kill_process()

View File

@@ -1,545 +0,0 @@
"""Training management screen for Axolotl TUI."""
import subprocess
import threading
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
Sparkline,
Static,
)
from axolotl.tui.screens.base import BaseScreen
@dataclass
class TrainingJob:
"""Represents a training job."""
id: str
config_path: str
status: str # pending, running, completed, failed
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
process: Optional[subprocess.Popen] = None
log_file: Optional[str] = None
current_epoch: int = 0
total_epochs: int = 0
current_loss: float = 0.0
losses: List[float] = None
def __post_init__(self):
if self.losses is None:
self.losses = []
class TrainingScreen(BaseScreen):
"""Training management screen."""
BINDINGS = [
Binding("ctrl+t", "new_training", "New Training"),
Binding("ctrl+r", "resume_training", "Resume"),
Binding("ctrl+x", "stop_training", "Stop"),
Binding("ctrl+l", "view_logs", "View Logs"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.training-container {
layout: vertical;
height: 100%;
}
.job-list-container {
height: 40%;
border: solid $primary;
padding: 1;
margin: 1;
}
.job-details-container {
height: 60%;
padding: 1;
}
.control-panel {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
border: solid $secondary;
margin: 1;
}
.control-panel Button {
margin: 0 1;
}
.metrics-panel {
layout: horizontal;
height: 10;
border: solid $primary;
padding: 1;
margin: 1;
}
.metric-card {
width: 25%;
border: tall $surface;
padding: 1;
margin: 0 1;
}
.metric-label {
text-style: bold;
color: $text-muted;
}
.metric-value {
text-style: bold;
text-align: center;
padding: 1;
}
.log-viewer {
border: solid $warning;
padding: 1;
margin: 1;
}
#training-logs {
height: 100%;
}
DataTable {
height: 100%;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
.sparkline-container {
height: 5;
border: solid $success;
padding: 1;
margin: 1;
}
"""
def __init__(self):
"""Initialize the training screen."""
super().__init__(
title="Training Management",
subtitle="Launch, monitor, and manage training jobs",
)
self.jobs: Dict[str, TrainingJob] = {}
self.selected_job_id: Optional[str] = None
self.update_timer = None
def compose(self) -> ComposeResult:
"""Compose the training screen layout."""
yield Header()
yield Container(
Static("🦾 Training Management", classes="screen-title"),
Static(
"Launch, monitor, and manage training jobs", classes="screen-subtitle"
),
Container(
Container(
Label("Active Training Jobs"),
DataTable(id="job-table"),
classes="job-list-container",
),
Container(
Button("New Training", id="new-training", variant="primary"),
Button("Resume", id="resume-training", variant="success"),
Button("Stop", id="stop-training", variant="error"),
Button("View Logs", id="view-logs", variant="default"),
Button("Clear Completed", id="clear-completed", variant="warning"),
Button("Refresh", id="refresh", variant="default"),
classes="control-panel",
),
Container(
Container(
Static("Current Epoch", classes="metric-label"),
Static("0 / 0", id="epoch-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Loss", classes="metric-label"),
Static("0.000", id="loss-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Status", classes="metric-label"),
Static("Idle", id="status-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Duration", classes="metric-label"),
Static(
"00:00:00", id="duration-metric", classes="metric-value"
),
classes="metric-card",
),
classes="metrics-panel",
),
Container(
Label("Loss History"),
Sparkline(
[],
id="loss-sparkline",
summary_function=min,
),
classes="sparkline-container",
),
Container(
Log(id="training-logs"),
classes="log-viewer",
),
classes="job-details-container",
),
classes="training-container",
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_job_table()
self.start_update_timer()
log = self.query_one("#training-logs", Log)
log.write_line(
"Training manager ready. Select a configuration to start training."
)
def setup_job_table(self) -> None:
"""Setup the job table."""
table = self.query_one("#job-table", DataTable)
table.add_columns("ID", "Config", "Status", "Epoch", "Loss", "Duration")
table.cursor_type = "row"
table.zebra_stripes = True
def start_update_timer(self) -> None:
"""Start the periodic update timer."""
self.set_interval(2.0, self.update_job_status)
@work(thread=True)
async def update_job_status(self) -> None:
"""Update job status periodically."""
for job_id, job in self.jobs.items():
if job.status == "running" and job.process:
poll = job.process.poll()
if poll is not None:
if poll == 0:
job.status = "completed"
else:
job.status = "failed"
job.end_time = datetime.now()
self.refresh_job_table()
self.update_selected_job_metrics()
def refresh_job_table(self) -> None:
"""Refresh the job table."""
table = self.query_one("#job-table", DataTable)
table.clear()
for job_id, job in self.jobs.items():
duration = self.calculate_duration(job)
table.add_row(
job_id[:8],
Path(job.config_path).name,
job.status,
f"{job.current_epoch}/{job.total_epochs}",
f"{job.current_loss:.4f}" if job.current_loss else "N/A",
duration,
)
def calculate_duration(self, job: TrainingJob) -> str:
"""Calculate job duration."""
if not job.start_time:
return "00:00:00"
end_time = job.end_time or datetime.now()
duration = end_time - job.start_time
hours = int(duration.total_seconds() // 3600)
minutes = int((duration.total_seconds() % 3600) // 60)
seconds = int(duration.total_seconds() % 60)
return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
def update_selected_job_metrics(self) -> None:
"""Update metrics for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
self.query_one("#epoch-metric", Static).update(
f"{job.current_epoch} / {job.total_epochs}"
)
self.query_one("#loss-metric", Static).update(
f"{job.current_loss:.4f}" if job.current_loss else "N/A"
)
self.query_one("#status-metric", Static).update(job.status.upper())
self.query_one("#duration-metric", Static).update(self.calculate_duration(job))
if job.losses:
sparkline = self.query_one("#loss-sparkline", Sparkline)
sparkline.data = job.losses[-50:] # Show last 50 loss values
@on(DataTable.RowSelected)
def handle_row_selected(self, event: DataTable.RowSelected) -> None:
"""Handle job selection from table."""
if event.cursor_row >= 0:
job_ids = list(self.jobs.keys())
if event.cursor_row < len(job_ids):
self.selected_job_id = job_ids[event.cursor_row]
self.update_selected_job_metrics()
self.load_job_logs()
def load_job_logs(self) -> None:
"""Load logs for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
if job.log_file and Path(job.log_file).exists():
try:
with open(job.log_file, "r") as f:
content = f.read()
log = self.query_one("#training-logs", Log)
log.clear()
for line in content.split("\n")[-100:]: # Show last 100 lines
if line.strip():
log.write_line(line)
except Exception as e:
log = self.query_one("#training-logs", Log)
log.write_line(f"Error loading logs: {str(e)}")
@on(Button.Pressed, "#new-training")
async def handle_new_training(self) -> None:
"""Start a new training job."""
from axolotl.tui.dialogs.training import NewTrainingDialog
dialog = NewTrainingDialog()
result = await self.app.push_screen_wait(dialog)
if result and "config_path" in result:
await self.start_training_job(
result["config_path"], result.get("launcher", "accelerate")
)
@work(thread=True)
async def start_training_job(
self, config_path: str, launcher: str = "accelerate"
) -> None:
"""Start a training job."""
import uuid
from datetime import datetime
job_id = str(uuid.uuid4())
log_file = f"/tmp/axolotl_training_{job_id}.log"
job = TrainingJob(
id=job_id,
config_path=config_path,
status="pending",
start_time=datetime.now(),
log_file=log_file,
total_epochs=3, # Default, should parse from config
)
self.jobs[job_id] = job
self.selected_job_id = job_id
log = self.query_one("#training-logs", Log)
log.clear()
log.write_line(f"🚀 Starting training job {job_id[:8]}...")
log.write_line(f"Config: {config_path}")
log.write_line(f"Launcher: {launcher}")
try:
if launcher == "accelerate":
cmd = ["accelerate", "launch", "-m", "axolotl.cli.train", config_path]
else:
cmd = [
"torchrun",
"--nproc_per_node=1",
"-m",
"axolotl.cli.train",
config_path,
]
with open(log_file, "w") as f:
process = subprocess.Popen(
cmd,
stdout=f,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
job.process = process
job.status = "running"
log.write_line("✅ Training started successfully!")
self.refresh_job_table()
self.monitor_training_output(job_id)
except Exception as e:
job.status = "failed"
job.end_time = datetime.now()
log.write_line(f"❌ Failed to start training: {str(e)}")
self.refresh_job_table()
def monitor_training_output(self, job_id: str) -> None:
"""Monitor training output and extract metrics."""
if job_id not in self.jobs:
return
job = self.jobs[job_id]
if not job.log_file:
return
def tail_log():
import re
import time
with open(job.log_file, "r") as f:
f.seek(0, 2) # Go to end of file
while job.status == "running":
line = f.readline()
if line:
# Parse training metrics from log
epoch_match = re.search(r"Epoch (\d+)/(\d+)", line)
if epoch_match:
job.current_epoch = int(epoch_match.group(1))
job.total_epochs = int(epoch_match.group(2))
loss_match = re.search(
r"loss['\"]?\s*:\s*([\d.]+)", line, re.IGNORECASE
)
if loss_match:
job.current_loss = float(loss_match.group(1))
job.losses.append(job.current_loss)
# Update log viewer
self.call_from_thread(self.append_training_log, line.strip())
else:
time.sleep(0.5)
thread = threading.Thread(target=tail_log, daemon=True)
thread.start()
def append_training_log(self, line: str) -> None:
"""Append line to training log."""
log = self.query_one("#training-logs", Log)
log.write_line(line)
@on(Button.Pressed, "#stop-training")
def handle_stop_training(self) -> None:
"""Stop selected training job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
log = self.query_one("#training-logs", Log)
log.write_line("⚠️ No job selected")
return
job = self.jobs[self.selected_job_id]
if job.status == "running" and job.process:
job.process.terminate()
job.status = "stopped"
job.end_time = datetime.now()
log = self.query_one("#training-logs", Log)
log.write_line(f"🛑 Training job {job.id[:8]} stopped")
self.refresh_job_table()
@on(Button.Pressed, "#resume-training")
async def handle_resume_training(self) -> None:
"""Resume a stopped training job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
log = self.query_one("#training-logs", Log)
log.write_line("⚠️ No job selected")
return
job = self.jobs[self.selected_job_id]
if job.status in ["stopped", "failed"]:
await self.start_training_job(job.config_path)
@on(Button.Pressed, "#clear-completed")
def handle_clear_completed(self) -> None:
"""Clear completed jobs from the list."""
completed_jobs = [
job_id
for job_id, job in self.jobs.items()
if job.status in ["completed", "failed", "stopped"]
]
for job_id in completed_jobs:
del self.jobs[job_id]
self.refresh_job_table()
log = self.query_one("#training-logs", Log)
log.write_line(f"🧹 Cleared {len(completed_jobs)} completed jobs")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh the job list and metrics."""
self.refresh_job_table()
self.update_selected_job_metrics()
if self.selected_job_id:
self.load_job_logs()
@on(Button.Pressed, "#view-logs")
def handle_view_logs(self) -> None:
"""View full logs for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
if job.log_file and Path(job.log_file).exists():
import subprocess
subprocess.run(["less", job.log_file])
def action_new_training(self) -> None:
"""Start a new training job."""
self.handle_new_training()
def action_stop_training(self) -> None:
"""Stop selected training job."""
self.handle_stop_training()
def action_resume_training(self) -> None:
"""Resume selected training job."""
self.handle_resume_training()
def action_refresh(self) -> None:
"""Refresh the display."""
self.handle_refresh()

View File

@@ -16,7 +16,7 @@ from packaging.version import Version, parse
def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
unsupported_devices = {"RTX 6000 Ada", "L40S"}
unsupported_devices = {"RTX 6000 Ada", "L40S", "A40"}
try:
device_names, device_count = get_gpu_info()
if 1 < device_count < 8:

View File

@@ -109,6 +109,12 @@ class AxolotlInputConfig(
"description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
},
)
reinit_weights: bool | None = Field(
default=None,
json_schema_extra={
"description": "Reinitialize model weights randomly instead of loading pretrained weights"
},
)
trainer_cls: str | None = Field(
default=None,

119
tests/e2e/test_diffusion.py Normal file
View File

@@ -0,0 +1,119 @@
"""E2E smoke test for diffusion training plugin."""
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
class TestDiffusion:
"""Test case for diffusion training plugin."""
def test_diffusion_smoke_test(self, temp_dir):
"""
Smoke test for diffusion training to ensure the plugin loads and trains without
error.
"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"trust_remote_code": True,
"sequence_len": 256,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
"logging_steps": 1,
"eval_steps": 3,
# Diffusion-specific config
"plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
"diffusion_mask_token_id": 16,
"diffusion_eps": 1e-3,
"diffusion_importance_weighting": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_diffusion_sft_labels(self, temp_dir):
"""Test that diffusion training properly handles SFT data with labels."""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"trust_remote_code": True,
"sequence_len": 256,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
"logging_steps": 1,
"eval_steps": 2,
# Diffusion-specific config
"plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
"diffusion_mask_token_id": 16,
"diffusion_eps": 1e-3,
"diffusion_importance_weighting": True,
# Ensure we have proper SFT labels
"train_on_inputs": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
# Verify that the dataset has labels
sample = dataset_meta.train_dataset[0]
assert "labels" in sample, "SFT dataset should have labels"
# Check that some labels are -100 (prompt tokens)
labels = sample["labels"]
if hasattr(labels, "tolist"):
labels = labels.tolist()
assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens"
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -0,0 +1,271 @@
"""Tests for diffusion trainer integration."""
# pylint: disable=redefined-outer-name,protected-access
from unittest.mock import Mock
import pytest
import torch
from axolotl.integrations.diffusion.trainer import DiffusionTrainer
from axolotl.utils.dict import DictDefault
@pytest.fixture
def mock_tokenizer():
"""Create a mock tokenizer."""
tokenizer = Mock()
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
tokenizer.pad_token_id = 0
return tokenizer
@pytest.fixture
def diffusion_config():
"""Create a diffusion config."""
return DictDefault(
{
"mask_token_id": 32000,
"eps": 1e-3,
"importance_weighting": False,
"sample_packing": False,
}
)
@pytest.fixture
def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
"""Create a diffusion trainer instance for testing methods directly."""
# Create a minimal trainer instance just for testing methods
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
trainer.config = diffusion_config
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos
trainer.processing_class = mock_tokenizer
trainer.store_metrics = Mock() # Mock metrics storage
return trainer
class TestDiffusionTrainer:
"""Test the DiffusionTrainer class."""
def test_forward_process_basic(self, diffusion_trainer_instance):
"""Test basic forward process without labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
noisy_batch, masked_indices, p_mask = (
diffusion_trainer_instance._forward_process(input_ids, eps=0.1)
)
# Check shapes
assert noisy_batch.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape
# Check that special tokens are not masked
special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
assert not masked_indices[special_token_positions].any()
# Check that mask token is applied
mask_token_id = diffusion_trainer_instance._config.mask_token_id
masked_positions = masked_indices
if masked_positions.any():
assert (noisy_batch[masked_positions] == mask_token_id).all()
def test_forward_process_with_labels(self, diffusion_trainer_instance):
"""Test forward process with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
noisy_batch, masked_indices, p_mask = (
diffusion_trainer_instance._forward_process(
input_ids, labels=labels, eps=0.1
)
)
# Check shapes
assert noisy_batch.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape
# Check that only answer tokens can be masked (where labels != -100)
non_answer_mask = labels == -100
# No masking should occur on non-answer tokens
assert not masked_indices[non_answer_mask].any()
# p_mask should be the same for all positions (sampled timestep),
# but masking is only applied to answer tokens
assert p_mask.shape == input_ids.shape
# Verify that masked_indices respects the answer mask
assert not masked_indices[non_answer_mask].any()
def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
"""Test forward process with attention mask."""
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
_, masked_indices, p_mask = diffusion_trainer_instance._forward_process(
input_ids, attention_mask=attention_mask, eps=0.1
)
# Check that padding tokens are not masked
padding_positions = attention_mask == 0
assert not masked_indices[padding_positions].any()
assert (p_mask[padding_positions] == 0).all()
def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):
"""Test bidirectional attention mask without sample packing."""
input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
input_ids
)
# Should be all-to-all attention
expected_shape = (1, 1, 4, 4)
assert mask.shape == expected_shape
assert mask.all()
def test_bidirectional_attention_mask_with_packing(
self, diffusion_trainer_instance
):
"""Test bidirectional attention mask with sample packing."""
diffusion_trainer_instance._config.sample_packing = True
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
# Sample IDs: first sample (1), second sample (2)
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Check that tokens within same sample can attend to each other
# but not across samples
assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other
assert mask[0, 0, 1, 2].item()
assert not mask[0, 0, 0, 3].item() # Can't attend across samples
assert not mask[0, 0, 2, 4].item()
assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other
def test_compute_loss_basic(self, diffusion_trainer_instance):
"""Test basic loss computation."""
# Mock model that returns logits
mock_model = Mock()
mock_outputs = Mock()
vocab_size = 1000
seq_len = 5
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
mock_model.return_value = mock_outputs
mock_model.training = True
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
mock_model, input_ids
)
# Check that loss is computed
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
assert outputs == mock_outputs
# Check that metrics were stored
diffusion_trainer_instance.store_metrics.assert_called_once()
def test_compute_loss_with_labels(self, diffusion_trainer_instance):
"""Test loss computation with SFT labels."""
# Mock model
mock_model = Mock()
mock_outputs = Mock()
vocab_size = 1000
seq_len = 5
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
mock_model.return_value = mock_outputs
mock_model.training = True
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
mock_model, input_ids, labels=labels
)
# Check that loss is computed
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
# Check that SFT metrics were added
call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
assert "answer_ratio" in call_args
assert "avg_answer_length" in call_args
def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
"""Test loss computation when no tokens are masked."""
# Mock model
mock_model = Mock()
mock_outputs = Mock()
vocab_size = 1000
seq_len = 3
mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
mock_model.return_value = mock_outputs
mock_model.training = True
# Only special tokens (which won't be masked)
input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
mock_model, input_ids
)
# Loss should be zero when no tokens are masked
assert loss.item() == 0.0
assert loss.requires_grad
def test_cache_special_token_ids(self, diffusion_trainer_instance):
"""Test caching of special token IDs."""
# Should cache BOS, EOS, PAD tokens
expected_tokens = {0, 1, 2} # pad, bos, eos
assert diffusion_trainer_instance._special_token_ids == expected_tokens
def test_cache_special_token_ids_no_tokenizer(self):
"""Test caching when no tokenizer is available."""
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
trainer.processing_class = None
trainer._cache_special_token_ids()
assert trainer._special_token_ids == set()
def test_main_compute_loss_interface(self, diffusion_trainer_instance):
"""Test the main compute_loss interface."""
# Mock model
mock_model = Mock()
mock_outputs = Mock()
mock_outputs.logits = torch.randn(1, 5, 1000)
mock_model.return_value = mock_outputs
mock_model.training = True
inputs = {
"input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
"attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
"labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
}
# Test without return_outputs
loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)
assert isinstance(loss, torch.Tensor)
# Test with return_outputs
loss, outputs = diffusion_trainer_instance.compute_loss(
mock_model, inputs, return_outputs=True
)
assert isinstance(loss, torch.Tensor)
assert outputs == mock_outputs
def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
"""Test that missing input_ids raises ValueError."""
mock_model = Mock()
inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}
with pytest.raises(ValueError, match="input_ids is required"):
diffusion_trainer_instance.compute_loss(mock_model, inputs)