diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 313c68c57..7a1118056 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -17,7 +17,7 @@ except ImportError:
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
-from axolotl.monkeypatch.utils import get_cu_seqlens
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
 
 def forward(
@@ -93,7 +93,7 @@ def forward(
             output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
         else:
             qkv = rearrange(qkv, "b s ... -> (b s) ...")
-            cu_q_lens, max_s = get_cu_seqlens(key_padding_mask)
+            cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
             cu_q_lens = cu_q_lens.squeeze()
             output = flash_attn_varlen_qkvpacked_func(
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 855c95304..345b9640c 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -2,12 +2,23 @@
 utility helpers for distributed checks
 """
 import torch.distributed as dist
+from accelerate import Accelerator
+
+accelerate = None  # pylint: disable=invalid-name
+
+
+def load_accelerate():
+    global accelerate  # pylint: disable=global-statement
+    accelerate = Accelerator()
 
 
 def is_distributed():
     """
     Check if distributed training is initialized.
     """
+    global accelerate  # pylint: disable=global-statement
+    if not accelerate:
+        accelerate = Accelerator()
     return dist.is_available() and dist.is_initialized()
 
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 3a6ba7591..82380183c 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -202,13 +202,14 @@ class AxolotlTrainer(Trainer):
                     collate_fn=self.data_collator,
                     sampler=eval_sampler,
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                    sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
+                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
                     device_count=int(os.environ.get("WORLD_SIZE", 1)),
                 )
             )
         return super().get_eval_dataloader(eval_dataset)
 
     def compute_loss(self, model, inputs, return_outputs=False):
+        # use one's weighted cross entropy loss calc
         # if self.args.sample_packing:
         #     labels = inputs.pop("labels")
         #     outputs = model(**inputs)
@@ -321,7 +322,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
             ),
             sampler=sampler,
             packing_efficiency_estimate=cfg.sample_packing_eff_est,
-            sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
+            sample_packing_seq_len_multiplier=cfg.micro_batch_size,
             device_count=int(os.environ.get("WORLD_SIZE", 1)),
         )
         data_loader_len = data_loader.len_w_stats()
@@ -466,7 +467,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             else "cosine",
             weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
             sample_packing=cfg.sample_packing if cfg.sample_packing else False,
-            sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier or 1,
+            sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
             **training_arguments_kwargs,
         )
diff --git a/tests/test_expand_mask.py b/tests/test_expand_mask.py
index 885d7f0d7..01241c295 100644
--- a/tests/test_expand_mask.py
+++ b/tests/test_expand_mask.py
@@ -36,7 +36,6 @@ class TestExpandMask(unittest.TestCase):
                 ],
             ]
         )
-        print(repr(_expand_mask(mask, dtype)))
         # Check that the output matches the expected output
         self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
 
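
Note on the llama_attn_hijack_flash.py change above: with sample packing, several training samples share one row and none of their tokens are padding, so a key_padding_mask cannot tell flash attention where one packed sample ends and the next begins. The position ids can, because each packed sub-sequence restarts its positions at 0. The code below is a minimal sketch of that idea, not the actual axolotl implementation of get_cu_seqlens_from_pos_ids; the helper name and the assumption that every sub-sequence starts at position 0 are illustrative.

import torch


def cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor):
    """Sketch: derive (cu_seqlens, max_seqlen) from packed position ids.

    Assumes position_ids has shape (batch, seq_len) and that every packed
    sub-sequence restarts its positions at 0, e.g. a single packed row
    [0, 1, 2, 0, 1, 0, 1, 2, 3] holds sub-sequences of lengths 3, 2, and 4.
    """
    flat = position_ids.flatten()
    # Each reset to position 0 marks the start of a new packed sub-sequence.
    starts = torch.nonzero(flat == 0, as_tuple=True)[0]
    ends = torch.cat([starts[1:], starts.new_tensor([flat.numel()])])
    seqlens = ends - starts
    # flash_attn_varlen_qkvpacked_func expects int32 cumulative lengths
    # starting at 0, plus the longest sub-sequence length.
    cu_seqlens = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(0)]).to(torch.int32)
    return cu_seqlens, int(seqlens.max())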
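A quick check of the sketch on a single packed row:

pos_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
cu_seqlens, max_s = cu_seqlens_from_pos_ids_sketch(pos_ids)
# cu_seqlens -> tensor([0, 3, 5, 9], dtype=torch.int32); max_s -> 4

These boundaries index into the flattened token stream, which is why the patched forward rearranges qkv to "(b s) ..." before calling flash_attn_varlen_qkvpacked_func. A padding-mask-based helper cannot recover them, since packed samples consist entirely of real tokens.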