Merge branch 'main' into fix-preview
This commit is contained in:
@@ -3,7 +3,7 @@ default_language_version:
|
|||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v5.0.0
|
rev: v6.0.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@@ -23,7 +23,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/pylint-dev/pylint
|
- repo: https://github.com/pylint-dev/pylint
|
||||||
rev: v3.3.7
|
rev: v3.3.8
|
||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
|
|||||||
@@ -124,6 +124,9 @@ def launch_training(
|
|||||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
|
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||||
elif launcher == "python":
|
elif launcher == "python":
|
||||||
_launch_python_training(cfg_file, kwargs)
|
_launch_python_training(cfg_file, kwargs)
|
||||||
|
elif launcher is None:
|
||||||
|
# handle ray train launch
|
||||||
|
_launch_python_training(cfg_file, kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _launch_cloud_training(
|
def _launch_cloud_training(
|
||||||
|
|||||||
@@ -73,9 +73,6 @@ class PatchManager:
|
|||||||
self._apply_voxtral_patches()
|
self._apply_voxtral_patches()
|
||||||
|
|
||||||
def _apply_transformers_patches(self):
|
def _apply_transformers_patches(self):
|
||||||
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
|
|
||||||
patch_prepare_from_posids,
|
|
||||||
)
|
|
||||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||||
patch_evaluation_loop,
|
patch_evaluation_loop,
|
||||||
patch_maybe_log_save_evaluate,
|
patch_maybe_log_save_evaluate,
|
||||||
@@ -87,7 +84,6 @@ class PatchManager:
|
|||||||
and self.cfg.fsdp_version == 2
|
and self.cfg.fsdp_version == 2
|
||||||
)
|
)
|
||||||
|
|
||||||
patch_prepare_from_posids()
|
|
||||||
patch_evaluation_loop(patch_fsdp2)
|
patch_evaluation_loop(patch_fsdp2)
|
||||||
patch_maybe_log_save_evaluate()
|
patch_maybe_log_save_evaluate()
|
||||||
|
|
||||||
|
|||||||
@@ -1,87 +0,0 @@
|
|||||||
"""
|
|
||||||
Monkey patch to fix transformers.modeling_flash_attention_utils.
|
|
||||||
|
|
||||||
see https://github.com/huggingface/transformers/pull/39653/files
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_from_posids(query, key, value, position_ids):
    """
    Build the arguments needed to call `flash_attn_varlen_func` from position ids.

    The query, key and value states are flattened over their leading dimensions, and
    the cumulative length of each example is derived from `position_ids`: every
    position whose id equals 0 is treated as the first token of a new sequence.
    NOTE: ideally cumulative lengths should be prepared at the data collator stage.

    Arguments:
        query (`torch.Tensor`):
            Query state. Shape: (batch_size, query_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        position_ids (`torch.Tensor`):
            Integer tensor of position ids, typically of shape (batch_size, sequence_length).
            It is flattened before use, and a value of 0 marks the start of a sequence.
            The same boundaries are returned for query and key/value, so this presumably
            expects query_length == kv_seq_len -- confirm against the caller.

    Return:
        query (`torch.Tensor`):
            Flattened query state. Shape: (batch_size * query_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Flattened key state. Shape: (batch_size * kv_seq_len, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Flattened value state. Shape: (batch_size * kv_seq_len, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            int32 indices `0..N-1` of every token in the flattened `position_ids`.
        (cu_seqlens_q, cu_seqlens_k) (`tuple[torch.Tensor]`):
            The cumulative sequence lengths for the target (query) and source (key, value),
            used to index into ragged (unpadded) tensors. Both entries are the same int32
            tensor of shape (number of zero position ids + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
            Length of the longest sequence, returned twice (once for the query, once for
            key/value).
    """
    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
    value = value.contiguous().view(-1, value.size(-2), value.size(-1))

    position_ids = position_ids.flatten()
    indices_q = torch.arange(
        position_ids.size(0), device=position_ids.device, dtype=torch.int32
    )

    # Sequence boundaries are the flattened offsets where the position id resets
    # to 0, followed by the total token count as the closing boundary.
    cu_seq_lens = torch.cat(
        (
            indices_q[position_ids == 0],
            torch.tensor(
                position_ids.size(), device=position_ids.device, dtype=torch.int32
            ),
        )
    )
    # NOTE: With torch compile, this will cause a graph break if you don't set
    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
    # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
    # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
    # We should use cu_seq_lens instead of position_ids to get the max length since
    # position_ids is not always increasing for some models (e.g. qwen2-vl).
    max_length = cu_seq_lens.diff().max().item()
    return (
        query,
        key,
        value,
        indices_q,
        (cu_seq_lens, cu_seq_lens),
        (max_length, max_length),
    )
|
|
||||||
|
|
||||||
|
|
||||||
def patch_prepare_from_posids():
    """
    Install the local `_prepare_from_posids` into `transformers.modeling_flash_attention_utils`.

    The replacement is bound twice: once through the imported module attribute and
    once through the module object registered in `sys.modules`.
    """
    import transformers.modeling_flash_attention_utils

    replacement = _prepare_from_posids

    # pylint: disable=protected-access
    transformers.modeling_flash_attention_utils._prepare_from_posids = replacement

    # Also bind it on the entry registered in `sys.modules`.
    registered_module = sys.modules["transformers.modeling_flash_attention_utils"]
    setattr(registered_module, "_prepare_from_posids", replacement)
|
|
||||||
@@ -10,7 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
from tests.e2e.utils import (
|
||||||
|
check_tensorboard,
|
||||||
|
require_torch_2_7_0,
|
||||||
|
require_torch_lt_2_6_0,
|
||||||
|
)
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
@@ -139,3 +143,71 @@ class TestMultiGPURay:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch_2_7_0
@pytest.mark.parametrize("gradient_accumulation_steps", [1, 2])
def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
    """
    Launch a two-step, sample-packed SFT run with FSDP2 through the Ray launcher.

    The config is written to `<temp_dir>/config.yaml`, `axolotl train` is started
    with `--use-ray` and two Ray workers, and the tensorboard `train/train_loss`
    scalar is then checked against 2.3.
    """
    # pylint: disable=duplicate-code
    fsdp_settings = {
        "offload_params": False,
        "cpu_ram_efficient_loading": False,
        "transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
        "state_dict_type": "FULL_STATE_DICT",
        "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "reshard_after_forward": True,
    }
    alpaca_subset = {
        "path": "tatsu-lab/alpaca",
        "type": "alpaca",
        "split": "train[:10%]",
    }
    cfg = DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "sample_packing": True,
            "pad_to_sequence_len": True,
            "sequence_len": 1024,
            "val_set_size": 0.01,
            "special_tokens": {
                "pad_token": "<|endoftext|>",
            },
            "datasets": [alpaca_subset],
            "num_epochs": 1,
            "max_steps": 2,
            "micro_batch_size": 1,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "output_dir": temp_dir,
            "dataset_prepared_path": temp_dir + "/last_run_prepared",
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch",
            "lr_scheduler": "cosine",
            "flash_attention": True,
            "fsdp_version": 2,
            "fsdp_config": fsdp_settings,
            "use_tensorboard": True,
            "save_first_step": False,
        }
    )

    # write cfg to yaml file
    output_root = Path(temp_dir)
    output_root.mkdir(parents=True, exist_ok=True)
    config_path = output_root / "config.yaml"
    serialized_cfg = yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)
    with open(config_path, "w", encoding="utf-8") as fout:
        fout.write(serialized_cfg)

    launch_cmd = [
        "axolotl",
        "train",
        str(config_path),
        "--use-ray",
        "--ray-num-workers",
        "2",
    ]
    execute_subprocess_async(launch_cmd)

    check_tensorboard(
        temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
    )
|
||||||
|
|||||||
Reference in New Issue
Block a user