diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index dadac90c3..4c9268529 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ default_language_version:
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0
+    rev: v6.0.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
@@ -23,7 +23,7 @@ repos:
     hooks:
      - id: flake8
   - repo: https://github.com/pylint-dev/pylint
-    rev: v3.3.7
+    rev: v3.3.8
     hooks:
      - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
diff --git a/src/axolotl/cli/utils/train.py b/src/axolotl/cli/utils/train.py
index 6fb99570b..16ec62440 100644
--- a/src/axolotl/cli/utils/train.py
+++ b/src/axolotl/cli/utils/train.py
@@ -124,6 +124,9 @@ def launch_training(
         _launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
     elif launcher == "python":
         _launch_python_training(cfg_file, kwargs)
+    elif launcher is None:
+        # handle ray train launch
+        _launch_python_training(cfg_file, kwargs)
 
 
 def _launch_cloud_training(
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 795fc3e37..f1ca3c725 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -73,9 +73,6 @@ class PatchManager:
         self._apply_voxtral_patches()
 
     def _apply_transformers_patches(self):
-        from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
-            patch_prepare_from_posids,
-        )
         from axolotl.monkeypatch.transformers.trainer_loss_calc import (
             patch_evaluation_loop,
             patch_maybe_log_save_evaluate,
@@ -87,7 +84,6 @@ class PatchManager:
             and self.cfg.fsdp_version == 2
         )
 
-        patch_prepare_from_posids()
         patch_evaluation_loop(patch_fsdp2)
         patch_maybe_log_save_evaluate()
 
diff --git a/src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py b/src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py
deleted file mode 100644
index 1bd8ac6bc..000000000
--- a/src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py
+++ /dev/null
@@ -1,87 +0,0 @@
-"""
-Monkey patch to fix transformers.modeling_flash_attention_utils.
-
-see https://github.com/huggingface/transformers/pull/39653/files
-"""
-
-import sys
-
-import torch
-
-
-def _prepare_from_posids(query, key, value, position_ids):
-    """
-    This function returns necessary arguments to call `flash_attn_varlen_func`.
-    All three query, key, value states will be flattened.
-    Cumulative lengths of each examples in the batch will be extracted from position_ids.
-    NOTE: ideally cumulative lengths should be prepared at the data collator stage
-    Arguments:
-        query (`torch.Tensor`):
-            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
-        key (`torch.Tensor`):
-            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
-        value (`torch.Tensor`):
-            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
-        position_ids (`torch.Tensor`):
-            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
-    Return:
-        query (`torch.Tensor`):
-            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
-        key (`torch.Tensor`):
-            Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
-        value (`torch.Tensor`):
-            Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
-        indices_q (`torch.Tensor`):
-            The indices of non-masked tokens from the flattened input target sequence.
-        (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
-            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
-        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
-            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
-    """
-    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
-    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
-    value = value.contiguous().view(-1, value.size(-2), value.size(-1))
-
-    position_ids = position_ids.flatten()
-    indices_q = torch.arange(
-        position_ids.size(0), device=position_ids.device, dtype=torch.int32
-    )
-
-    cu_seq_lens = torch.cat(
-        (
-            indices_q[position_ids == 0],
-            torch.tensor(
-                position_ids.size(), device=position_ids.device, dtype=torch.int32
-            ),
-        )
-    )
-    # NOTE: With torch compile, this will cause a graph break if you don't set
-    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
-    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
-    # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
-    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
-    # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
-    # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
-    # for some models (e.g. qwen2-vl).
-    max_length = cu_seq_lens.diff().max().item()
-    return (
-        query,
-        key,
-        value,
-        indices_q,
-        (cu_seq_lens, cu_seq_lens),
-        (max_length, max_length),
-    )
-
-
-def patch_prepare_from_posids():
-    import transformers.modeling_flash_attention_utils
-
-    transformers.modeling_flash_attention_utils._prepare_from_posids = (  # pylint: disable=protected-access
-        _prepare_from_posids
-    )
-    setattr(
-        sys.modules["transformers.modeling_flash_attention_utils"],
-        "_prepare_from_posids",
-        _prepare_from_posids,
-    )
diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py
index dd1422296..7f1278abf 100644
--- a/tests/e2e/multigpu/test_ray.py
+++ b/tests/e2e/multigpu/test_ray.py
@@ -10,7 +10,11 @@
 from accelerate.test_utils import execute_subprocess_async
 
 from axolotl.utils.dict import DictDefault
-from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
+from tests.e2e.utils import (
+    check_tensorboard,
+    require_torch_2_7_0,
+    require_torch_lt_2_6_0,
+)
 
 AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
 
@@ -139,3 +143,71 @@ class TestMultiGPURay:
         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
         )
+
+    @require_torch_2_7_0
+    @pytest.mark.parametrize(
+        "gradient_accumulation_steps",
+        [1, 2],
+    )
+    def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sample_packing": True,
+                "pad_to_sequence_len": True,
+                "sequence_len": 1024,
+                "val_set_size": 0.01,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                        "split": "train[:10%]",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 2,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": gradient_accumulation_steps,
+                "output_dir": temp_dir,
+                "dataset_prepared_path": temp_dir + "/last_run_prepared",
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "fsdp_version": 2,
+                "fsdp_config": {
+                    "offload_params": False,
+                    "cpu_ram_efficient_loading": False,
+                    "transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
+                    "state_dict_type": "FULL_STATE_DICT",
+                    "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                    "reshard_after_forward": True,
+                },
+                "use_tensorboard": True,
+                "save_first_step": False,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
+                "--use-ray",
+                "--ray-num-workers",
+                "2",
+            ]
+        )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
+        )