# EBFT Strided: Full-parameter Llama-3.1-8B on SwallowCode, 100 steps
# The feature network is CPU-offloaded to fit on a single 32 GB GPU
#
# Run: CUDA_VISIBLE_DEVICES=0 python -m axolotl.cli.train examples/ebft/llama-8b-ebft-strided-fft.yaml

base_model: meta-llama/Llama-3.1-8B

rl: ebft

ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0
  advantage_estimator: rloo

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:5000]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100

learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01

bf16: auto
flash_attention: false  # strided EBFT uses flex_attention at runtime
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-llama8b-strided
wandb_project: ebft
wandb_name: llama8b-strided-fft-100steps
logging_steps: 1
save_steps: 50
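
# Notes (comments only, not parsed as config; the arithmetic and the layer
# mapping below are our reading of the values above, not verified against
# the EBFT plugin's implementation):
# - Effective batch: micro_batch_size 1 x gradient_accumulation_steps 4
#   = 4 sequences per optimizer step; at sequence_len 1024 that is
#   ~4,096 tokens/step and ~409,600 tokens over the 100-step run.
# - feature_layers are given as fractions; assuming they are resolved
#   against Llama-3.1-8B's 32 decoder layers, 0.25 / 0.5 / 0.75 would
#   correspond to layers 8, 16, and 24.
# - With stride 8 and context_length 8, a 1024-token sequence would
#   yield ~128 strided windows (assuming one window every `stride`
#   tokens), each sampled n_samples_per_prompt = 4 times at
#   temperature 0.6 for the RLOO advantage estimate.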