This commit is contained in:
Dan Saunders
2025-08-14 01:53:24 -04:00
parent d8b63804bc
commit 0a9341acde
2 changed files with 5 additions and 9 deletions

View File

@@ -1,8 +1,4 @@
""" """Diffusion LM training plugin init."""
Diffusion LM training plugin for Axolotl.
This plugin enables diffusion language model training using the LLaDA approach.
"""
from .args import DiffusionArgs from .args import DiffusionArgs
from .plugin import DiffusionPlugin from .plugin import DiffusionPlugin

View File

@@ -1,4 +1,4 @@
"""Configuration arguments for diffusion LM training.""" """Config args for diffusion LM training."""
from typing import Literal from typing import Literal
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
class DiffusionArgs(BaseModel): class DiffusionArgs(BaseModel):
"""Arguments for diffusion LM training plugin.""" """Arguments for diffusion LM training plugin."""
# Noise schedule configuration # Noise schedule config
noise_schedule: Literal["linear", "cosine"] = Field( noise_schedule: Literal["linear", "cosine"] = Field(
default="linear", description="Type of noise schedule for diffusion training" default="linear", description="Type of noise schedule for diffusion training"
) )
@@ -28,7 +28,7 @@ class DiffusionArgs(BaseModel):
default=1000, ge=1, description="Number of diffusion timesteps" default=1000, ge=1, description="Number of diffusion timesteps"
) )
# Forward process parameters # Forward process config
eps: float = Field( eps: float = Field(
default=1e-3, default=1e-3,
ge=0.0, ge=0.0,
@@ -36,7 +36,7 @@ class DiffusionArgs(BaseModel):
description="Epsilon value for minimum masking probability in forward process", description="Epsilon value for minimum masking probability in forward process",
) )
# Training configuration # Training config
importance_weighting: bool = Field( importance_weighting: bool = Field(
default=True, default=True,
description="Apply importance weighting to loss based on masking probability", description="Apply importance weighting to loss based on masking probability",