diff --git a/docs/docker.qmd b/docs/docker.qmd
index e208d3222..d665eaf5b 100644
--- a/docs/docker.qmd
+++ b/docs/docker.qmd
@@ -8,6 +8,10 @@ format:
 
 This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
 
+::: {.callout-important}
+For Blackwell GPUs, please use the image tags with PyTorch 2.7.0 and CUDA 12.8.
+:::
+
 ## Base
 
 The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
diff --git a/docs/installation.qmd b/docs/installation.qmd
index 0cf5ffceb..b429992b6 100644
--- a/docs/installation.qmd
+++ b/docs/installation.qmd
@@ -25,6 +25,10 @@ Please make sure to have Pytorch installed before installing Axolotl in your loc
 Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
 :::
 
+::: {.callout-important}
+For Blackwell GPUs, please use PyTorch 2.7.0 and CUDA 12.8.
+:::
+
 ### PyPI Installation (Recommended) {#sec-pypi}
 
 ```{.bash}
@@ -72,6 +76,10 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
 ```
 :::
 
+::: {.callout-important}
+For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
+:::
+
 Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
 
 ## Cloud Environments {#sec-cloud}
diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py
index 4329d9f13..7d733cfc1 100644
--- a/src/axolotl/monkeypatch/ring_attn/patch.py
+++ b/src/axolotl/monkeypatch/ring_attn/patch.py
@@ -51,6 +51,8 @@ NEW_PREPARE_DATALOADER_CODE = """
 submesh_fsdp_size = 1
 
 
 def get_ring_attn_group() -> dist.ProcessGroup:
     """Getter for ring attention group on this rank."""
+    if RING_ATTN_GROUP is None:
+        raise RuntimeError("register_ring_attn() not yet called")
     return RING_ATTN_GROUP
@@ -69,8 +71,8 @@ def register_ring_attn(
 
     Args:
         sequence_parallel_degree: Sequence parallelism factor.
-        heads_k_stride: Sequence parallelism K head stride size. Passed
-            through to `ring_flash_attn.substitute_hf_flash_attn`.
+        heads_k_stride: Sequence parallelism K head stride size. Passed through to
+            the `varlen_llama3` `ring_flash_attn` implementation.
         ring_attn_func: `ring_flash_attn` ring attention implemention. If
             sample packing is enabled, it must be a `varlen` function; otherwise,
             it must be a `batch` function.
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 638cee559..047a66e94 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -424,6 +424,20 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
 
             LOG.debug(f"Should train: {should_train}")
 
+            # Turn is not trainable, so skip finding the turn indices, unless
+            # this is the last turn and train_on_eos/train_on_eot is "all"
+            if not should_train and (
+                self.train_on_eos != "all" and self.train_on_eot != "all"
+            ):
+                if index == len(turns) - 1:
+                    LOG.warning(
+                        "Last turn is not trainable; skipping the turn index lookup. "
+                        "This may cause the last EOT/EOS token to be incorrectly unmasked. "
+                        "This is likely a dataset design issue. Please ensure the last turn is trainable."
+                    )
+
+                continue
+
             turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
 
             LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 3aaba84ff..f946cdb2a 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -210,6 +210,7 @@ def execute_training(
             sequence_parallel_degree=cfg.sequence_parallel_degree,
             gradient_accumulation_steps=cfg.gradient_accumulation_steps,
             ring_attn_func=cfg.ring_attn_func,
+            heads_k_stride=cfg.heads_k_stride,
         )
     )
 
diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py
index 6e4f9bada..2ae93acad 100644
--- a/src/axolotl/utils/ctx_managers/sequence_parallel.py
+++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py
@@ -12,6 +12,9 @@ from transformers.utils import ModelOutput
 
 from axolotl.monkeypatch.ring_attn.patch import (
     get_ring_attn_group,
+    patch_prepare_data_loader,
+    patch_prepare_device_mesh,
+    register_ring_attn,
     update_ring_attn_params,
 )
 from axolotl.utils.schemas.enums import RingAttnFunc
@@ -169,6 +172,8 @@ class SequenceParallelContextManager:
         sequence_parallel_degree: Number of processes to split sequences over.
         gradient_accumulation_steps: Number of steps to accumulate gradients over.
         ring_attn_func: Which ring attention function to use. Currently unused.
+        heads_k_stride: Sequence parallelism K head stride size. Passed through to
+            the `varlen_llama3` `ring_flash_attn` implementation.
     """
 
     def __init__(
@@ -177,14 +182,17 @@ class SequenceParallelContextManager:
         sequence_parallel_degree: int,
         gradient_accumulation_steps: int,
         ring_attn_func: RingAttnFunc,
+        heads_k_stride: int | None,
     ):
         self.models = models
         self.sequence_parallel_degree = sequence_parallel_degree
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.ring_attn_func = ring_attn_func
-        self.process_group = get_ring_attn_group()
+        self.heads_k_stride = heads_k_stride
+        self._register_ring_attn()
 
-        # Initialize sequence parallel group details
+        # Set distributed info for local rank
+        self.process_group = get_ring_attn_group()
         self.local_rank = dist.get_rank(self.process_group)
         self.local_world_size = dist.get_world_size(self.process_group)
 
@@ -205,6 +213,33 @@
         )
 
     def __enter__(self):
+        self._register_model_hooks()
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Remove all hooks
+        for handle in self.hook_handles:
+            handle.remove()
+        self.hook_handles = []
+
+        # TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
+
+    def _register_ring_attn(self):
+        # Initialize ring attn for sequence parallelism
+        register_ring_attn(
+            sequence_parallel_degree=self.sequence_parallel_degree,
+            heads_k_stride=self.heads_k_stride,
+            ring_attn_func=self.ring_attn_func,
+        )
+
+        # Patches for accelerate functionality
+        patch_prepare_data_loader()
+        patch_prepare_device_mesh(
+            sequence_parallel_degree=self.sequence_parallel_degree
+        )
+
+    def _register_model_hooks(self):
         # Forward pre-hook to apply sequence parallelism
         def sequence_parallel_pre_hook(_, args, kwargs):
             # Get parameter names from the model's forward function
@@ -230,7 +265,7 @@
         # Forward post-hook to gather outputs
         def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
             # Gather the sharded outputs
-            output = self.gather_outputs(output)
+            output = self._gather_outputs(output)
 
             # Remove padding if it was added
             if self.pad_len > 0:
@@ -253,15 +288,7 @@ class SequenceParallelContextManager:
             model.register_forward_hook(sequence_parallel_post_hook)
         )
 
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        # Remove all hooks
-        for handle in self.hook_handles:
-            handle.remove()
-        self.hook_handles = []
-
-    def gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
+    def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
         """Gather sharded outputs from all ranks and reconstruct the full tensor."""
         for key, value in output.items():
             if isinstance(value, torch.Tensor) and value.dim() > 1:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 6236f78e8..cd7499869 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -59,7 +59,6 @@ from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
     patch_for_multipack,
 )
-from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import get_chat_template_from_config
@@ -681,27 +680,6 @@
 
             patch_self_attn_lora(self.cfg)
 
-        if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
-            from axolotl.monkeypatch.ring_attn import (
-                patch_prepare_data_loader,
-                patch_prepare_device_mesh,
-                register_ring_attn,
-            )
-
-            # Initialize ring attn for sequence parallelism. This must be done after
-            # model init but before the first forward pass, since it modifies flash
-            # attn to use ring comm for SP training across multiple GPUs.
-            if get_ring_attn_group() is None:  # If already set, this is already patched
-                register_ring_attn(
-                    sequence_parallel_degree=self.cfg.sequence_parallel_degree,
-                    heads_k_stride=self.cfg.heads_k_stride,
-                    ring_attn_func=self.cfg.ring_attn_func,
-                )
-                patch_prepare_data_loader()
-                patch_prepare_device_mesh(
-                    sequence_parallel_degree=self.cfg.sequence_parallel_degree
-                )
-
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
             if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py
index 35d8060b1..69a904548 100644
--- a/tests/core/test_trainer_builder.py
+++ b/tests/core/test_trainer_builder.py
@@ -340,3 +340,27 @@ class TestHFCausalTrainerBuilder:
         # SFT specific
         assert training_arguments.sample_packing is False
         assert training_arguments.eval_sample_packing is False
+
+
+class TestTrainerClsPlugin:
+    """
+    TestCase class for the trainer builder with a plugin.
+    """
+
+    def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer):
+        """
+        Test that the trainer cls is not None when a plugin is enabled.
+
+        Fixes #2693
+        """
+        cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
+        cfg.rl = RLType.KTO
+
+        # An AttributeError is expected, since we don't pass regular model configs to the RL trainer builder.
+        # If it instead throws `TypeError: None is not a callable object`, trainer_cls could be None.
+        with pytest.raises(
+            AttributeError, match=r".*'tuple' object has no attribute 'config'.*"
+        ):
+            builder = HFRLTrainerBuilder(cfg, model, tokenizer)
+
+            builder.build(100)
diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py
index 83faa779f..2b4d11b30 100644
--- a/tests/e2e/patched/test_sp.py
+++ b/tests/e2e/patched/test_sp.py
@@ -84,16 +84,16 @@
     def test_get_ring_attn_group_no_registration(
         self, mock_world_size, mock_rank, partial_state
     ):
-        """Test that get_ring_attn_group returns None when no group has been registered."""
+        """Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
         # Setup mocks
         mock_world_size.return_value = 4
         mock_rank.return_value = 0
 
-        # Get the group without registration
-        group = get_ring_attn_group()
-
-        # Verify that None was returned
-        assert group is None
+        # Verify that RuntimeError is raised when no group is registered
+        with pytest.raises(
+            RuntimeError, match="register_ring_attn\\(\\) not yet called"
+        ):
+            get_ring_attn_group()
 
     @patch("torch.distributed.new_group")
     @patch("torch.distributed.get_rank")
@@ -323,8 +323,11 @@
         lambda **kwargs: None,
     )
 
-    def test_world_size_one(self, sequence_parallel_batch):
+    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
+    def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
         """Test that function returns original batch when world size is 1."""
+        mock_get_ring_attn_group.return_value = 0
+
         result, _, _ = apply_sequence_parallelism(
             batch=sequence_parallel_batch,
             local_rank=0,
@@ -336,8 +339,11 @@
         # Should return the original batch unchanged
         assert result == sequence_parallel_batch
 
-    def test_batch_ring_rank0(self, sequence_parallel_batch):
+    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
+    def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
         """Test BATCH_RING sharding for rank 0 in a 2-process group."""
+        mock_get_ring_attn_group.return_value = 0
+
         batch = sequence_parallel_batch
         seq_len = batch["input_ids"].size(1)
 
@@ -359,8 +365,11 @@
             result["position_ids"], batch["position_ids"][:, : seq_len // 2]
         )
 
-    def test_batch_ring_rank1(self, sequence_parallel_batch):
+    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
+    def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
         """Test BATCH_RING sharding for rank 1 in a 2-process group."""
+        mock_get_ring_attn_group.return_value = 0
+
         batch = sequence_parallel_batch
         seq_len = batch["input_ids"].size(1)
         original_input_ids = batch["input_ids"].clone()
@@ -419,8 +428,13 @@
         # assert torch.equal(result_rank0["input_ids"], rank0_expected)
         # assert torch.equal(result_rank1["input_ids"], rank1_expected)
 
-    def test_partial_application(self, sequence_parallel_batch):
+    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
+    def test_partial_application(
+        self, mock_get_ring_attn_group, sequence_parallel_batch
+    ):
         """Test that we can create a partially applied version of the function."""
+        mock_get_ring_attn_group.return_value = 0
+
         batch = sequence_parallel_batch
         original_input_ids = batch["input_ids"].clone()
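
As a usage sketch of the refactor above (assuming the usual Axolotl setup for `cfg`, `model`, and a tokenized `batch`; only the constructor arguments are taken from this diff), the context manager is now self-registering:

```python
# Sketch only -- mirrors the execute_training call site in this diff.
# `cfg`, `model`, and `batch` are assumed to come from the usual Axolotl
# training setup and are not defined here.
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager

context_manager = SequenceParallelContextManager(
    models=[model],
    sequence_parallel_degree=cfg.sequence_parallel_degree,
    gradient_accumulation_steps=cfg.gradient_accumulation_steps,
    ring_attn_func=cfg.ring_attn_func,
    heads_k_stride=cfg.heads_k_stride,  # new: forwarded to ring_flash_attn
)

# Construction calls register_ring_attn() and applies the accelerate patches,
# so get_ring_attn_group() no longer raises afterwards; entering the context
# installs the forward hooks, and exiting removes them.
with context_manager:
    outputs = model(**batch)
```

This matches the move of ring-attention registration out of `ModelLoader` and into the context manager's `__init__`, so `get_ring_attn_group()` can fail loudly (rather than return `None`) whenever registration has not happened.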