This commit is contained in:
Dan Saunders
2025-08-16 00:14:44 +00:00
parent e19be0c2d9
commit 234b7b3126
6 changed files with 7 additions and 5 deletions

View File

@@ -15,7 +15,7 @@ plugins:
noise_schedule: "cosine"
min_mask_ratio: 0.15
max_mask_ratio: 0.85
-num_diffusion_steps: 2000
+num_diffusion_steps: 128
eps: 5e-4
importance_weighting: true

View File

@@ -12,7 +12,7 @@ plugins:
noise_schedule: "linear"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
-num_diffusion_steps: 1000
+num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true

View File

@@ -591,7 +591,7 @@ class AxolotlTrainer(
logs[key] = values.sum().item()
else:
raise NotImplementedError(
-"Metric reduction must be one of [mean, min, max]"
+"Metric reduction must be one of [mean, min, max, sum]"
)
logs[key] = round(logs[key], 4)

View File

@@ -33,7 +33,7 @@ plugins:
noise_schedule: "linear" # or "cosine"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
-num_diffusion_steps: 1000
+num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true

View File

@@ -25,7 +25,7 @@ class DiffusionArgs(BaseModel):
description="Maximum masking ratio for diffusion noise schedule",
)
num_diffusion_steps: int = Field(
-default=1000, ge=1, description="Number of diffusion timesteps"
+default=128, ge=1, description="Number of diffusion timesteps"
)
# Forward process config

View File

@@ -65,6 +65,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
self._special_token_ids = special_tokens
+@torch.compile
def _forward_process(
self,
input_ids: torch.Tensor,
@@ -120,6 +121,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
return noisy_batch, masked_indices, p_mask
+@torch.compile
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor: