diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 179cfc5fa..313c68c57 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -94,6 +94,7 @@ def forward( else: qkv = rearrange(qkv, "b s ... -> (b s) ...") cu_q_lens, max_s = get_cu_seqlens(key_padding_mask) + cu_q_lens = cu_q_lens.squeeze() output = flash_attn_varlen_qkvpacked_func( qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 9af8674d3..3a6ba7591 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -299,6 +299,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): / cfg.sample_packing_eff_est / 2048 // cfg.batch_size + // int(os.environ.get("WORLD_SIZE", 1)) ) - 1 )