fix steps calculation
This commit is contained in:
@@ -94,6 +94,7 @@ def forward(
|
|||||||
else:
|
else:
|
||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
cu_q_lens, max_s = get_cu_seqlens(key_padding_mask)
|
cu_q_lens, max_s = get_cu_seqlens(key_padding_mask)
|
||||||
|
cu_q_lens = cu_q_lens.squeeze()
|
||||||
|
|
||||||
output = flash_attn_varlen_qkvpacked_func(
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
|
|||||||
@@ -299,6 +299,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|||||||
/ cfg.sample_packing_eff_est
|
/ cfg.sample_packing_eff_est
|
||||||
/ 2048
|
/ 2048
|
||||||
// cfg.batch_size
|
// cfg.batch_size
|
||||||
|
// int(os.environ.get("WORLD_SIZE", 1))
|
||||||
)
|
)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user