nits
This commit is contained in:
@@ -1,8 +1,4 @@
|
||||
"""
|
||||
Diffusion LM training plugin for Axolotl.
|
||||
|
||||
This plugin enables diffusion language model training using the LLaDA approach.
|
||||
"""
|
||||
"""Diffusion LM training plugin init."""
|
||||
|
||||
from .args import DiffusionArgs
|
||||
from .plugin import DiffusionPlugin
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Configuration arguments for diffusion LM training."""
|
||||
"""Config args for diffusion LM training."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
|
||||
class DiffusionArgs(BaseModel):
|
||||
"""Arguments for diffusion LM training plugin."""
|
||||
|
||||
-# Noise schedule configuration
|
||||
+# Noise schedule config
|
||||
noise_schedule: Literal["linear", "cosine"] = Field(
|
||||
default="linear", description="Type of noise schedule for diffusion training"
|
||||
)
|
||||
@@ -28,7 +28,7 @@ class DiffusionArgs(BaseModel):
|
||||
default=1000, ge=1, description="Number of diffusion timesteps"
|
||||
)
|
||||
|
||||
-# Forward process parameters
|
||||
+# Forward process config
|
||||
eps: float = Field(
|
||||
default=1e-3,
|
||||
ge=0.0,
|
||||
@@ -36,7 +36,7 @@ class DiffusionArgs(BaseModel):
|
||||
description="Epsilon value for minimum masking probability in forward process",
|
||||
)
|
||||
|
||||
-# Training configuration
|
||||
+# Training config
|
||||
importance_weighting: bool = Field(
|
||||
default=True,
|
||||
description="Apply importance weighting to loss based on masking probability",
|
||||
|
||||
Reference in New Issue
Block a user