calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier
@@ -17,7 +17,7 @@ except ImportError:
 
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
-from axolotl.monkeypatch.utils import get_cu_seqlens
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
 
 def forward(
@@ -93,7 +93,7 @@ def forward(
             output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
         else:
             qkv = rearrange(qkv, "b s ... -> (b s) ...")
-            cu_q_lens, max_s = get_cu_seqlens(key_padding_mask)
+            cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
             cu_q_lens = cu_q_lens.squeeze()
 
             output = flash_attn_varlen_qkvpacked_func(
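Note on the change above: with sample packing, each batch row concatenates several sequences, and position_ids restart at 0 at every packed-sequence boundary. The cumulative sequence lengths that flash_attn_varlen_qkvpacked_func expects can therefore be read directly off the position ids instead of a padding mask. A minimal sketch of that derivation, assuming a single packed row and using a hypothetical helper name (not the repo's exact implementation):

    import torch

    def cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor):
        # Assumes shape (1, total_len) or (total_len,), with position_ids
        # restarting at 0 wherever a new packed sequence begins.
        pos = position_ids.flatten()
        starts = (pos == 0).nonzero(as_tuple=True)[0]  # boundary indices
        ends = torch.cat([starts[1:], torch.tensor([pos.numel()], device=pos.device)])
        seqlens = ends - starts  # length of each packed sequence
        cu_seqlens = torch.cat(
            [torch.zeros(1, dtype=torch.int32, device=pos.device),
             seqlens.cumsum(0).to(torch.int32)]
        )
        return cu_seqlens, int(seqlens.max())

For example, position_ids [0, 1, 2, 0, 1, 0, 1, 2, 3] would yield cu_seqlens [0, 3, 5, 9] and max_s 4, matching the (cu_q_lens, max_s) pair consumed above.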
@@ -2,12 +2,23 @@
 utility helpers for distributed checks
 """
 import torch.distributed as dist
+from accelerate import Accelerator
+
+accelerate = None  # pylint: disable=invalid-name
+
+
+def load_accelerate():
+    global accelerate  # pylint: disable=global-statement
+    accelerate = Accelerator()
 
 
 def is_distributed():
     """
     Check if distributed training is initialized.
     """
+    global accelerate  # pylint: disable=global-statement
+    if not accelerate:
+        accelerate = Accelerator()
     return dist.is_available() and dist.is_initialized()
 
 
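The "fix distributed barrier" part of the commit title corresponds to this lazy Accelerator setup: the Accelerator is created on first use, presumably so the process group is in place before any collective call. A hypothetical caller-side guard in the same style (this helper is not part of the diff):

    import torch.distributed as dist

    def barrier():
        # Synchronize only when a process group is actually initialized,
        # so single-process runs skip the collective instead of failing.
        if is_distributed():
            dist.barrier()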
@@ -202,13 +202,14 @@ class AxolotlTrainer(Trainer):
                     collate_fn=self.data_collator,
                     sampler=eval_sampler,
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                    sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
+                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
                     device_count=int(os.environ.get("WORLD_SIZE", 1)),
                 )
             )
         return super().get_eval_dataloader(eval_dataset)
 
     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
         # if self.args.sample_packing:
         #     labels = inputs.pop("labels")
         #     outputs = model(**inputs)
@@ -321,7 +322,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
             ),
             sampler=sampler,
             packing_efficiency_estimate=cfg.sample_packing_eff_est,
-            sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
+            sample_packing_seq_len_multiplier=cfg.micro_batch_size,
             device_count=int(os.environ.get("WORLD_SIZE", 1)),
         )
         data_loader_len = data_loader.len_w_stats()
@@ -466,7 +467,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         sample_packing=cfg.sample_packing if cfg.sample_packing else False,
-        sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier or 1,
+        sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
         **training_arguments_kwargs,
     )
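Taken together, the three trainer hunks drop the standalone sample_packing_seq_len_multiplier setting and reuse batch sizes already in the config: micro_batch_size for training and step estimation, eval_batch_size for evaluation. As a hedged arithmetic sketch of why a batch size works as the multiplier (the concrete values below are illustrative, not from the repo):

    # One packed row stands in for `multiplier` unpacked rows, so the packed
    # token budget per row scales with the micro batch size.
    sequence_len = 2048      # assumed base context length
    micro_batch_size = 4     # assumed per-device batch size
    packed_row_len = sequence_len * micro_batch_size  # 8192 tokens per packed row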
@@ -36,7 +36,6 @@ class TestExpandMask(unittest.TestCase):
                 ],
             ]
         )
-        print(repr(_expand_mask(mask, dtype)))
         # Check that the output matches the expected output
         self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
 