From 6ec76ddb4cdcac5b66ecc0e42f6202c8290da209 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 8 Aug 2023 05:13:21 -0400 Subject: [PATCH] fix steps calculation --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 1 + src/axolotl/utils/trainer.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 179cfc5fa..313c68c57 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -94,6 +94,7 @@ def forward( else: qkv = rearrange(qkv, "b s ... -> (b s) ...") cu_q_lens, max_s = get_cu_seqlens(key_padding_mask) + cu_q_lens = cu_q_lens.squeeze() output = flash_attn_varlen_qkvpacked_func( qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 9af8674d3..3a6ba7591 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -299,6 +299,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): / cfg.sample_packing_eff_est / 2048 // cfg.batch_size + // int(os.environ.get("WORLD_SIZE", 1)) ) - 1 )