fix steps calculation

This commit is contained in:
Wing Lian
2023-08-08 05:13:21 -04:00
parent 21d307b15b
commit 6ec76ddb4c
2 changed files with 2 additions and 0 deletions

View File

@@ -94,6 +94,7 @@ def forward(
else:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens(key_padding_mask)
cu_q_lens = cu_q_lens.squeeze()
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True

View File

@@ -299,6 +299,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
/ cfg.sample_packing_eff_est
/ 2048
// cfg.batch_size
// int(os.environ.get("WORLD_SIZE", 1))
)
- 1
)