batch api HF adapter for ring-flash-attn; cleanup and improvements (#2520)
* batch api HF adapter for ring-flash-attn; cleanup and improvements * update * adding all batch ring-flash-attn methods via single adapter * removing pad_to_sequence_len=False for now * fix * updating docs to include batch SP * review comments * fixes for batch API funcs, simplify * fixes * fix * updates * add batch_zigzag smoke test
This commit is contained in:
@@ -27,6 +27,9 @@ To enable sequence parallelism, add the following to your configuration file:
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
|
||||
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
|
||||
ring_attn_func:
|
||||
```
|
||||
|
||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
Reference in New Issue
Block a user