From 234b7b31265312233f2036f3169b3b0c9c739774 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sat, 16 Aug 2025 00:14:44 +0000 Subject: [PATCH] nits --- examples/llama-3/diffusion-3.2-1b-pretrain.yaml | 2 +- examples/llama-3/diffusion-3.2-1b-sft.yaml | 2 +- src/axolotl/core/trainers/base.py | 2 +- src/axolotl/integrations/diffusion/README.md | 2 +- src/axolotl/integrations/diffusion/args.py | 2 +- src/axolotl/integrations/diffusion/trainer.py | 2 ++ 6 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/llama-3/diffusion-3.2-1b-pretrain.yaml b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml index 965e248eb..ca0271ba7 100644 --- a/examples/llama-3/diffusion-3.2-1b-pretrain.yaml +++ b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml @@ -15,7 +15,7 @@ plugins: noise_schedule: "cosine" min_mask_ratio: 0.15 max_mask_ratio: 0.85 -num_diffusion_steps: 2000 +num_diffusion_steps: 128 eps: 5e-4 importance_weighting: true diff --git a/examples/llama-3/diffusion-3.2-1b-sft.yaml b/examples/llama-3/diffusion-3.2-1b-sft.yaml index 30c2504b4..019fefbb3 100644 --- a/examples/llama-3/diffusion-3.2-1b-sft.yaml +++ b/examples/llama-3/diffusion-3.2-1b-sft.yaml @@ -12,7 +12,7 @@ plugins: noise_schedule: "linear" min_mask_ratio: 0.1 max_mask_ratio: 0.9 -num_diffusion_steps: 1000 +num_diffusion_steps: 128 eps: 1e-3 importance_weighting: true diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 77a9cc83a..c433b2a39 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -591,7 +591,7 @@ class AxolotlTrainer( logs[key] = values.sum().item() else: raise NotImplementedError( - "Metric reduction must be one of [mean, min, max]" + "Metric reduction must be one of [mean, min, max, sum]" ) logs[key] = round(logs[key], 4) diff --git a/src/axolotl/integrations/diffusion/README.md b/src/axolotl/integrations/diffusion/README.md index ce2b0c8f7..7a1e909a6 100644 --- a/src/axolotl/integrations/diffusion/README.md +++ b/src/axolotl/integrations/diffusion/README.md @@ -33,7 +33,7 @@ plugins: noise_schedule: "linear" # or "cosine" min_mask_ratio: 0.1 max_mask_ratio: 0.9 -num_diffusion_steps: 1000 +num_diffusion_steps: 128 eps: 1e-3 importance_weighting: true diff --git a/src/axolotl/integrations/diffusion/args.py b/src/axolotl/integrations/diffusion/args.py index f01db087c..0e27e7362 100644 --- a/src/axolotl/integrations/diffusion/args.py +++ b/src/axolotl/integrations/diffusion/args.py @@ -25,7 +25,7 @@ class DiffusionArgs(BaseModel): description="Maximum masking ratio for diffusion noise schedule", ) num_diffusion_steps: int = Field( - default=1000, ge=1, description="Number of diffusion timesteps" + default=128, ge=1, description="Number of diffusion timesteps" ) # Forward process config diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py index ed81fd029..9bf000b6d 100644 --- a/src/axolotl/integrations/diffusion/trainer.py +++ b/src/axolotl/integrations/diffusion/trainer.py @@ -65,6 +65,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors self._special_token_ids = special_tokens + @torch.compile def _forward_process( self, input_ids: torch.Tensor, @@ -120,6 +121,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors return noisy_batch, masked_indices, p_mask + @torch.compile def _create_bidirectional_attention_mask( self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None ) -> torch.Tensor: