# EBFT Strided Structured Mode: For structured (prompt, completion) data
# Uses strided block-parallel generation on completion spans — no vLLM needed.
#
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided-structured.yaml

base_model: meta-llama/Llama-3.2-1B

rl: ebft
ebft:
  mode: strided               # strided block-parallel generation
  stride: 8                   # tokens between anchor points
  context_length: 8           # context window per block
  generate_max_len: 8         # tokens to generate per block
  n_samples_per_prompt: 4     # rollouts per document
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.03               # small CE weight for structured data
  advantage_estimator: rloo
  min_completion_prefix: 8    # skip anchors too close to the prompt boundary

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 1
# max_steps: 10

learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 5

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
flash_attention: false        # strided EBFT overrides this to flex_attention (or the eager fallback) at runtime
flex_attention: true
# The fused flex_attention kernel compiles itself; don't set torch_compile: true
# (full-model compile conflicts with gradient checkpointing + flex_attention).
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true         # required for flex_attention (non-reentrant causes CheckpointError)

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-strided-structured
wandb_project: ebft
logging_steps: 1
save_steps: 100