nits
This commit is contained in:
@@ -1,8 +1,4 @@
|
|||||||
"""
|
"""Diffusion LM training plugin init."""
|
||||||
Diffusion LM training plugin for Axolotl.
|
|
||||||
|
|
||||||
This plugin enables diffusion language model training using the LLaDA approach.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .args import DiffusionArgs
|
from .args import DiffusionArgs
|
||||||
from .plugin import DiffusionPlugin
|
from .plugin import DiffusionPlugin
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Configuration arguments for diffusion LM training."""
|
"""Config args for diffusion LM training."""
|
||||||
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
|
|||||||
class DiffusionArgs(BaseModel):
|
class DiffusionArgs(BaseModel):
|
||||||
"""Arguments for diffusion LM training plugin."""
|
"""Arguments for diffusion LM training plugin."""
|
||||||
|
|
||||||
# Noise schedule configuration
|
# Noise schedule config
|
||||||
noise_schedule: Literal["linear", "cosine"] = Field(
|
noise_schedule: Literal["linear", "cosine"] = Field(
|
||||||
default="linear", description="Type of noise schedule for diffusion training"
|
default="linear", description="Type of noise schedule for diffusion training"
|
||||||
)
|
)
|
||||||
@@ -28,7 +28,7 @@ class DiffusionArgs(BaseModel):
|
|||||||
default=1000, ge=1, description="Number of diffusion timesteps"
|
default=1000, ge=1, description="Number of diffusion timesteps"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Forward process parameters
|
# Forward process config
|
||||||
eps: float = Field(
|
eps: float = Field(
|
||||||
default=1e-3,
|
default=1e-3,
|
||||||
ge=0.0,
|
ge=0.0,
|
||||||
@@ -36,7 +36,7 @@ class DiffusionArgs(BaseModel):
|
|||||||
description="Epsilon value for minimum masking probability in forward process",
|
description="Epsilon value for minimum masking probability in forward process",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training configuration
|
# Training config
|
||||||
importance_weighting: bool = Field(
|
importance_weighting: bool = Field(
|
||||||
default=True,
|
default=True,
|
||||||
description="Apply importance weighting to loss based on masking probability",
|
description="Apply importance weighting to loss based on masking probability",
|
||||||
|
|||||||
Reference in New Issue
Block a user