diff --git a/examples/llama-3/fft-8b-tp.yml b/examples/llama-3/fft-8b-tp.yml index c61e033ee..e80380053 100644 --- a/examples/llama-3/fft-8b-tp.yml +++ b/examples/llama-3/fft-8b-tp.yml @@ -34,7 +34,7 @@ bf16: auto fp16: tf32: false -tensor_parallel: true +tensor_parallel: 'auto' gradient_checkpointing: true gradient_checkpointing_kwargs: