calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier

Wing Lian
2023-08-10 17:16:01 -04:00
parent 57d9bf711c
commit a07f432d9c
4 changed files with 17 additions and 6 deletions


@@ -17,7 +17,7 @@ except ImportError:
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
-from axolotl.monkeypatch.utils import get_cu_seqlens
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids


 def forward(
@@ -93,7 +93,7 @@ def forward(
             output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
         else:
             qkv = rearrange(qkv, "b s ... -> (b s) ...")
-            cu_q_lens, max_s = get_cu_seqlens(key_padding_mask)
+            cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
             cu_q_lens = cu_q_lens.squeeze()
             output = flash_attn_varlen_qkvpacked_func(
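
The switch away from key_padding_mask works because, in a packed batch, position_ids restart at 0 at the start of every packed-in sequence, so the zeros mark the sequence boundaries. A minimal sketch of that derivation (hypothetical helper name and shape assumptions; not the repository's exact get_cu_seqlens_from_pos_ids):

import torch

def cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor):
    """position_ids: (1, total_tokens) for a packed batch."""
    flat = position_ids.flatten()
    # Each packed sequence starts where the position counter resets to 0.
    starts = torch.nonzero(flat == 0).flatten()
    # Cumulative lengths [0, len_0, len_0 + len_1, ..., total_tokens],
    # int32 as expected by flash_attn_varlen_qkvpacked_func.
    cu_seqlens = torch.cat(
        [starts, torch.tensor([flat.numel()], device=flat.device)]
    ).to(torch.int32)
    max_seq_len = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return cu_seqlens, max_seq_len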


@@ -2,12 +2,23 @@
 utility helpers for distributed checks
 """
 import torch.distributed as dist
+from accelerate import Accelerator
+
+accelerate = None  # pylint: disable=invalid-name
+
+
+def load_accelerate():
+    global accelerate  # pylint: disable=global-statement
+    accelerate = Accelerator()


 def is_distributed():
     """
     Check if distributed training is initialized.
     """
+    global accelerate  # pylint: disable=global-statement
+    if not accelerate:
+        accelerate = Accelerator()
     return dist.is_available() and dist.is_initialized()
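
The "fix distributed barrier" part of the commit title hinges on this lazy Accelerator: constructing one initializes the default process group under accelerate launch / torchrun, so is_distributed() no longer reports False before the first synchronization point. A hedged sketch of a typical call site (the barrier helper and import path are assumptions, not shown in this diff):

import torch.distributed as dist

from axolotl.utils.distributed import is_distributed

def barrier():
    # Synchronize ranks only when a process group actually exists;
    # calling dist.barrier() in a single-process run would raise.
    if is_distributed():
        dist.barrier()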


@@ -202,13 +202,14 @@ class AxolotlTrainer(Trainer):
                     collate_fn=self.data_collator,
                     sampler=eval_sampler,
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                    sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
+                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
                     device_count=int(os.environ.get("WORLD_SIZE", 1)),
                 )
             )
         return super().get_eval_dataloader(eval_dataset)
+
     def compute_loss(self, model, inputs, return_outputs=False):
def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
         # if self.args.sample_packing:
         #     labels = inputs.pop("labels")
         #     outputs = model(**inputs)
@@ -321,7 +322,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
         ),
         sampler=sampler,
         packing_efficiency_estimate=cfg.sample_packing_eff_est,
-        sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
+        sample_packing_seq_len_multiplier=cfg.micro_batch_size,
         device_count=int(os.environ.get("WORLD_SIZE", 1)),
     )
     data_loader_len = data_loader.len_w_stats()
@@ -466,7 +467,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         sample_packing=cfg.sample_packing if cfg.sample_packing else False,
-        sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier or 1,
+        sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
         **training_arguments_kwargs,
     )
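
Together, these three call sites retire the standalone sample_packing_seq_len_multiplier setting and reuse the batch size already in play: micro_batch_size for training and step estimation, eval_batch_size for the eval dataloader. Back-of-the-envelope arithmetic under an assumed reading of the parameter (tokens per packed bin; an illustration, not the repository's exact formula):

sequence_len = 2048
micro_batch_size = 4  # now doubles as the seq-len multiplier for training
packing_efficiency_estimate = 0.95

tokens_per_bin = sequence_len * micro_batch_size  # 8192 tokens per packed batch
expected_usable = int(tokens_per_bin * packing_efficiency_estimate)  # ~7782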


@@ -36,7 +36,6 @@ class TestExpandMask(unittest.TestCase):
                 ],
             ]
         )
-        print(repr(_expand_mask(mask, dtype)))
         # Check that the output matches the expected output
         self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))