nits
This commit is contained in:
@@ -15,7 +15,7 @@ plugins:
|
||||
noise_schedule: "cosine"
|
||||
min_mask_ratio: 0.15
|
||||
max_mask_ratio: 0.85
|
||||
num_diffusion_steps: 2000
|
||||
num_diffusion_steps: 128
|
||||
eps: 5e-4
|
||||
importance_weighting: true
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ plugins:
|
||||
noise_schedule: "linear"
|
||||
min_mask_ratio: 0.1
|
||||
max_mask_ratio: 0.9
|
||||
num_diffusion_steps: 1000
|
||||
num_diffusion_steps: 128
|
||||
eps: 1e-3
|
||||
importance_weighting: true
|
||||
|
||||
|
||||
@@ -591,7 +591,7 @@ class AxolotlTrainer(
|
||||
logs[key] = values.sum().item()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Metric reduction must be one of [mean, min, max]"
|
||||
"Metric reduction must be one of [mean, min, max, sum]"
|
||||
)
|
||||
|
||||
logs[key] = round(logs[key], 4)
|
||||
|
||||
@@ -33,7 +33,7 @@ plugins:
|
||||
noise_schedule: "linear" # or "cosine"
|
||||
min_mask_ratio: 0.1
|
||||
max_mask_ratio: 0.9
|
||||
num_diffusion_steps: 1000
|
||||
num_diffusion_steps: 128
|
||||
eps: 1e-3
|
||||
importance_weighting: true
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class DiffusionArgs(BaseModel):
|
||||
description="Maximum masking ratio for diffusion noise schedule",
|
||||
)
|
||||
num_diffusion_steps: int = Field(
|
||||
default=1000, ge=1, description="Number of diffusion timesteps"
|
||||
default=128, ge=1, description="Number of diffusion timesteps"
|
||||
)
|
||||
|
||||
# Forward process config
|
||||
|
||||
@@ -65,6 +65,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
||||
|
||||
self._special_token_ids = special_tokens
|
||||
|
||||
@torch.compile
|
||||
def _forward_process(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -120,6 +121,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
||||
|
||||
return noisy_batch, masked_indices, p_mask
|
||||
|
||||
@torch.compile
|
||||
def _create_bidirectional_attention_mask(
|
||||
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user