# EBFT Strided: LoRA Llama-3.2-3B on SwallowCode, 100 steps
# Actor on GPU 0, frozen feature network on GPU 1
#
# Run:
#   CUDA_VISIBLE_DEVICES=0,1 python -m axolotl.cli.train examples/ebft/llama-3b-ebft-strided-fft.yaml
#
# NOTE(review): the path above ends in "-fft", but this config trains a LoRA
# adapter (adapter: lora); confirm it matches this file's actual name.

base_model: meta-llama/Llama-3.2-3B

rl: ebft
ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  # CE term is weighted at 0 here. The paper recommends 0.03 for a mixed
  # objective; 0.1 causes CE to dominate.
  ce_coef: 0.0
  advantage_estimator: rloo

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    # Quoted so the slice spec is unambiguously a string.
    split: "train[:5000]"

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100

learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01

adapter: lora
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
torch_dtype: bfloat16
flash_attention: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
torch_compile: true

ddp: false
# Place the entire actor on GPU 0 (matches "Actor on GPU 0" in the header).
device_map:
  "": 0

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-llama3b-strided

wandb_project: ebft
wandb_name: llama3b-strided-lora-100steps
logging_steps: 1
save_steps: 50