Compare commits
11 Commits
patch_lora
...
seq-parall
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee489d16bf | ||
|
|
d88e071120 | ||
|
|
a4170030ab | ||
|
|
bf842730a5 | ||
|
|
1db6ad60a7 | ||
|
|
29b366b2e1 | ||
|
|
b53a41372f | ||
|
|
02f45e94be | ||
|
|
954e192f38 | ||
|
|
8dfadc2b3c | ||
|
|
23a9fcb0a7 |
4
.github/workflows/multi-gpu-e2e.yml
vendored
4
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -4,6 +4,10 @@ on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'tests/e2e/multigpu/*.py'
|
||||
- 'requirements.txt'
|
||||
- 'setup.py'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/multi-gpu-e2e.yml'
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||
|
||||
@@ -37,15 +37,11 @@ temp_dir = tempfile.mkdtemp()
|
||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||
f.write(dockerfile_contents)
|
||||
|
||||
cicd_image = (
|
||||
Image.from_dockerfile(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
force_build=True,
|
||||
gpu="A10G",
|
||||
)
|
||||
.env(df_args)
|
||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
)
|
||||
cicd_image = Image.from_dockerfile(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
force_build=True,
|
||||
gpu="A10G",
|
||||
).env(df_args)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
|
||||
|
||||
@@ -407,7 +407,10 @@ save_total_limit: # Checkpoints saved at a time
|
||||
max_steps:
|
||||
|
||||
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
|
||||
include_tokens_per_second:
|
||||
include_tokens_per_second: # Optional[bool]
|
||||
|
||||
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
|
||||
auto_find_batch_size: # Optional[bool]
|
||||
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
|
||||
@@ -12,6 +12,7 @@ to leverage operator fusion and tensor re-use in order to improve speed and redu
|
||||
memory usage during the forward and backward passes of these calculations.
|
||||
|
||||
We currently support several common model architectures, including (but not limited to):
|
||||
|
||||
- `llama`
|
||||
- `mistral`
|
||||
- `qwen2`
|
||||
|
||||
@@ -13,12 +13,12 @@ liger-kernel==0.5.2
|
||||
packaging==23.2
|
||||
|
||||
peft==0.14.0
|
||||
transformers==4.48.3
|
||||
transformers==4.49.0
|
||||
tokenizers>=0.21.0
|
||||
accelerate==1.3.0
|
||||
datasets==3.2.0
|
||||
deepspeed==0.16.1
|
||||
trl==0.15.0
|
||||
trl==0.15.1
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
|
||||
@@ -123,8 +123,6 @@ class ModalCloud(Cloud):
|
||||
if env := self.get_env():
|
||||
image = image.env(env)
|
||||
|
||||
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
|
||||
return image
|
||||
|
||||
def get_secrets(self):
|
||||
|
||||
@@ -59,6 +59,7 @@ from axolotl.core.training_args import (
|
||||
AxolotlTrainingArguments,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType, get_extract_fn
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
@@ -746,6 +747,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
if self.cfg.sp_ulysses_degree:
|
||||
data_collator_kwargs["sp_extract_fn"] = get_extract_fn(
|
||||
USPRingAttnType.ZIGZAG,
|
||||
sp_ulysses_degree=self.cfg.sp_ulysses_degree
|
||||
)
|
||||
|
||||
if self.cfg.reward_model:
|
||||
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
@@ -78,7 +78,6 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.merge_adapter()
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
unwrapped_model.unmerge_adapter()
|
||||
# Remove base_model and base_layer prefixes
|
||||
state_dict = {
|
||||
k.removeprefix("base_model.model.")
|
||||
@@ -100,8 +99,10 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
}
|
||||
else:
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if self.accelerator.is_main_process:
|
||||
llm_model = (
|
||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
)
|
||||
llm_model.load_weights(state_dict.items())
|
||||
if self.accelerator.is_main_process:
|
||||
llm_model = (
|
||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
)
|
||||
llm_model.load_weights(state_dict.items())
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.unmerge_adapter()
|
||||
|
||||
@@ -206,6 +206,16 @@ class AxolotlTrainingMixins:
|
||||
},
|
||||
)
|
||||
|
||||
sp_ulysses_degree: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "Ulysses parallelism for hybrid sequence parallel long context attn"},
|
||||
)
|
||||
|
||||
sp_ring_degree: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "Ring attention parallelism for sequence parallel long context attn"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from yunchang import set_seq_parallel_pg, EXTRACT_FUNC_DICT
|
||||
|
||||
from axolotl.utils.distributed import get_world_size, get_rank
|
||||
|
||||
|
||||
class USPRingAttnType(Enum):
|
||||
BASIC = "basic"
|
||||
ZIGZAG = "zigzag"
|
||||
STRIPE = "stripe"
|
||||
|
||||
def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
|
||||
from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward
|
||||
|
||||
fa_forward = build_usp_fa_forward(ring_impl_type)
|
||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward
|
||||
|
||||
def get_extract_fn(ring_impl_type: USPRingAttnType, sp_ulysses_degree: int):
|
||||
fn = EXTRACT_FUNC_DICT["basic"]
|
||||
if ring_impl_type.value in EXTRACT_FUNC_DICT.keys():
|
||||
fn = EXTRACT_FUNC_DICT[ring_impl_type.value]
|
||||
|
||||
# map bad key upstream
|
||||
elif ring_impl_type == USPRingAttnType.STRIPE:
|
||||
fn = EXTRACT_FUNC_DICT["strip"]
|
||||
|
||||
world_size = get_world_size()
|
||||
rd = world_size // sp_ulysses_degree
|
||||
|
||||
return partial(fn, rank=get_rank(), world_size=world_size, rd=rd, ud=sp_ulysses_degree)
|
||||
|
||||
def set_usp_parallel_group(sp_ulysses_degree):
|
||||
"""
|
||||
setup distributed parallel group for USP attention
|
||||
make sure this gets called before building any USP attention modules
|
||||
:param sp_ulysses_degree:
|
||||
:return:
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
rank = get_rank()
|
||||
sp_ring_degree = world_size // sp_ulysses_degree
|
||||
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
|
||||
36
src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
Normal file
36
src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple, Callable
|
||||
|
||||
import torch
|
||||
from yunchang import LongContextAttention
|
||||
|
||||
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
|
||||
|
||||
|
||||
def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
|
||||
usp_attn = LongContextAttention(ring_impl_type.value)
|
||||
|
||||
def flash_attention_forward(
|
||||
module: torch.nn.Module, # pylint: disable=unused-argument
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor], # pylint: disable=unused-argument
|
||||
dropout: float = 0.0,
|
||||
scaling: Optional[float] = None,
|
||||
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
|
||||
softcap: Optional[float] = None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, None]:
|
||||
attn_output = usp_attn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=scaling,
|
||||
causal=True,
|
||||
softcap=softcap,
|
||||
)
|
||||
return attn_output, None
|
||||
|
||||
return flash_attention_forward
|
||||
@@ -4,12 +4,13 @@ import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import types
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
from peft import PeftModelForCausalLM
|
||||
from torch import nn
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers import AutoConfig
|
||||
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_mlp_geglu,
|
||||
@@ -95,90 +96,108 @@ def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tens
|
||||
return attn_output
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def patch_self_attn_lora(model: PreTrainedModel):
|
||||
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
"""
|
||||
Patches the attention classes in a transformer model with optimized LoRA implementations.
|
||||
Get the appropriate attention class by inspecting the model config.
|
||||
Uses dynamic import to support any model architecture that follows
|
||||
the standard transformers naming convention.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
The appropriate attention class for the model.
|
||||
|
||||
Raises:
|
||||
ValueError: If `base_model` not specified or attention class cannot be imported
|
||||
ImportError: If the model module or attention class doesn't exist
|
||||
"""
|
||||
if "base_model" not in cfg:
|
||||
raise ValueError("base_model must be specified in config")
|
||||
|
||||
# Get model config without loading the model
|
||||
model_config = AutoConfig.from_pretrained(cfg["base_model"])
|
||||
model_type = model_config.model_type
|
||||
|
||||
# Special case for model_type = "qwen2"
|
||||
if model_type == "qwen2":
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
|
||||
|
||||
return Qwen2Attention
|
||||
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
module = __import__(
|
||||
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
|
||||
)
|
||||
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
|
||||
|
||||
return attention_cls
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(
|
||||
f"Could not import attention class for model_type: {model_type}. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def patch_self_attn_lora(cfg: DictDefault):
|
||||
"""
|
||||
Given an `axolotl` config, this method patches the inferred attention class forward
|
||||
pass with optimized LoRA implementations.
|
||||
|
||||
It modifies the attention class to use optimized QKV and output projections. The
|
||||
original implementation is preserved and can be restored if needed.
|
||||
|
||||
Args:
|
||||
model: A HuggingFace transformers model.
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the required code blocks are not found in the attention
|
||||
implementation.
|
||||
"""
|
||||
# Find all attention modules in the model
|
||||
attention_modules = [
|
||||
module
|
||||
for module in model.modules()
|
||||
if "attention" in module.__class__.__name__.lower()
|
||||
and hasattr(module, "forward")
|
||||
]
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
if not attention_modules:
|
||||
LOG.warning("No attention modules found in model")
|
||||
# Check if already patched
|
||||
if hasattr(attention_cls, "_original_forward"):
|
||||
LOG.info(f"{attention_cls.__name__} already patched")
|
||||
return
|
||||
|
||||
attention_classes = {type(module) for module in attention_modules}
|
||||
LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}")
|
||||
self_attn_forward = inspect.getsource(attention_cls.forward)
|
||||
attention_cls._original_forward = self_attn_forward
|
||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||
|
||||
for attention_cls in attention_classes:
|
||||
# Skip if already patched
|
||||
if hasattr(attention_cls, "_original_forward"):
|
||||
LOG.info(f"{attention_cls.__name__} already patched")
|
||||
continue
|
||||
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
|
||||
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
|
||||
|
||||
# Get and store original forward implementation
|
||||
self_attn_forward = inspect.getsource(attention_cls.forward)
|
||||
attention_cls._original_forward = self_attn_forward
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
"def forward(",
|
||||
"def axolotl_attn_forward(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Remove indentation
|
||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||
# Load necessary imports
|
||||
module_name = attention_cls.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Verify required code blocks exist
|
||||
assert (
|
||||
ORIGINAL_QKV_CODE in self_attn_forward
|
||||
), f"Original QKV code not found in {attention_cls.__name__}"
|
||||
assert (
|
||||
ORIGINAL_O_CODE in self_attn_forward
|
||||
), f"Original O code not found in {attention_cls.__name__}"
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in self_attn_forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
# Replace code blocks
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
ORIGINAL_QKV_CODE, PATCHED_QKV_CODE
|
||||
)
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
"def forward(",
|
||||
"def axolotl_attn_forward(",
|
||||
1,
|
||||
)
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
# Import necessary symbols from the attention module
|
||||
module_name = attention_cls.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in self_attn_forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
if items_to_import:
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
|
||||
# Execute the new implementation
|
||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
||||
attention_cls.forward = (
|
||||
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
||||
attention_cls.forward = (
|
||||
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
def apply_lora_kernel_patches(
|
||||
|
||||
@@ -127,6 +127,8 @@ class ReLoRACallback(TrainerCallback):
|
||||
optimizer: torch.optim.Optimizer,
|
||||
**_kwargs,
|
||||
):
|
||||
if not optimizer:
|
||||
optimizer = state.optimizer
|
||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
|
||||
@@ -272,8 +272,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
dict(zip(feature_names, row))
|
||||
)
|
||||
for key, val in tokenized_prompt.items():
|
||||
for i in range(0, len(val), self.sequence_len):
|
||||
res[key].append(val[i : i + self.sequence_len])
|
||||
res[key].append(val)
|
||||
|
||||
# If there are no examples left, return an empty dictionary
|
||||
if not res:
|
||||
|
||||
@@ -3,7 +3,7 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -53,6 +53,7 @@ class DataCollatorForSeq2Seq:
|
||||
label_pad_token_id: int = -100
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
sp_extract_fn: Optional[Callable] = None
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
labels = None
|
||||
@@ -121,6 +122,10 @@ class DataCollatorForSeq2Seq:
|
||||
|
||||
return features
|
||||
|
||||
def seq_parallel_split(self, features):
|
||||
if self.sp_extract_fn:
|
||||
pass
|
||||
return features
|
||||
|
||||
@dataclass
|
||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
@@ -342,6 +342,7 @@ class LoraConfig(BaseModel):
|
||||
peft_use_dora: Optional[bool] = None
|
||||
peft_use_rslora: Optional[bool] = None
|
||||
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
||||
peft_init_lora_weights: Optional[Union[bool, str]] = None
|
||||
|
||||
qlora_sharded_model_loading: Optional[bool] = Field(
|
||||
default=False,
|
||||
@@ -831,6 +832,8 @@ class AxolotlInputConfig(
|
||||
|
||||
eager_attention: Optional[bool] = None
|
||||
|
||||
sp_ulysses_degree: Optional[int] = None
|
||||
|
||||
unsloth_cross_entropy_loss: Optional[bool] = None
|
||||
unsloth_lora_mlp: Optional[bool] = None
|
||||
unsloth_lora_qkv: Optional[bool] = None
|
||||
|
||||
@@ -172,10 +172,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
)
|
||||
|
||||
try:
|
||||
min_input_len = np.min(get_dataset_lengths(dataset))
|
||||
LOG.debug(f"min_input_len: {min_input_len}")
|
||||
max_input_len = np.max(get_dataset_lengths(dataset))
|
||||
LOG.debug(f"max_input_len: {max_input_len}")
|
||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||
min_input_len = np.min(ds_lengths)
|
||||
LOG.info(f"min_input_len: {min_input_len}")
|
||||
max_input_len = np.max(ds_lengths)
|
||||
LOG.info(f"max_input_len: {max_input_len}")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -86,6 +86,12 @@ def get_world_size():
|
||||
return int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_distributed():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_only():
|
||||
"""
|
||||
|
||||
@@ -439,6 +439,11 @@ class ModelLoader:
|
||||
|
||||
patch_mistral_cross_entropy()
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
if hasattr(self.model_config, "model_type"):
|
||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||
@@ -1023,12 +1028,6 @@ class ModelLoader:
|
||||
integrate_rope_embeddings()
|
||||
|
||||
def apply_lora_patch(self) -> None:
|
||||
"""Applies patching relevant to LoRA Triton kernels if enabled."""
|
||||
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora(self.model)
|
||||
|
||||
if (
|
||||
self.cfg.lora_mlp_kernel
|
||||
or self.cfg.lora_qkv_kernel
|
||||
@@ -1182,7 +1181,6 @@ class ModelLoader:
|
||||
if self.cfg.adapter is not None:
|
||||
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
||||
|
||||
# TODO: Deprecate this.
|
||||
self.apply_unsloth_lora_patch()
|
||||
self.apply_lora_patch()
|
||||
|
||||
@@ -1203,7 +1201,9 @@ def load_model(
|
||||
reference_model: bool = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||
"""Load a model for a given configuration and tokenizer."""
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
loader = ModelLoader(
|
||||
cfg,
|
||||
tokenizer,
|
||||
@@ -1321,6 +1321,8 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
if loftq_bits:
|
||||
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
|
||||
lora_config_kwargs["init_lora_weights"] = "loftq"
|
||||
if cfg.peft_init_lora_weights:
|
||||
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
|
||||
if cfg.peft_use_dora:
|
||||
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
||||
LOG.info("Initializing LoRA weights using dora. This might take longer.")
|
||||
|
||||
@@ -4,13 +4,17 @@ helper util to calculate dataset lengths
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_dataset_lengths(dataset):
|
||||
if "length" in dataset.data.column_names:
|
||||
lengths = np.array(dataset.data.column("length"))
|
||||
elif "position_ids" in dataset.data.column_names:
|
||||
position_ids = dataset.data.column("position_ids")
|
||||
def get_dataset_lengths(dataset, from_arrow=False):
|
||||
if "length" in dataset.column_names:
|
||||
lengths = np.array(dataset["length"])
|
||||
elif "position_ids" in dataset.column_names:
|
||||
position_ids = dataset["position_ids"]
|
||||
lengths = np.array([x[-1] + 1 for x in position_ids])
|
||||
else:
|
||||
input_ids = dataset.data.column("input_ids")
|
||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||
if from_arrow:
|
||||
input_ids = dataset.data.column("input_ids")
|
||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||
else:
|
||||
input_ids = dataset["input_ids"]
|
||||
lengths = np.array([len(seq) for seq in input_ids])
|
||||
return lengths
|
||||
|
||||
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (PoSE)",
|
||||
)
|
||||
elif cfg.sample_packing:
|
||||
elif cfg.sample_packing or cfg.sp_ulysses_degree:
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||
|
||||
@@ -9,14 +9,16 @@ from transformers import AutoModelForCausalLM, LlamaForCausalLM
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_mlp_geglu,
|
||||
apply_lora_mlp_swiglu,
|
||||
apply_lora_o,
|
||||
apply_lora_qkv,
|
||||
)
|
||||
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
|
||||
from axolotl.monkeypatch.lora_kernels import (
|
||||
apply_lora_kernel_patches,
|
||||
patch_self_attn_lora,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
MODEL_CONFIGS = [
|
||||
@@ -63,45 +65,15 @@ def small_llama_model():
|
||||
return LlamaForCausalLM(LlamaConfig(**config))
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
@pytest.fixture
|
||||
def minimal_cfg():
|
||||
"Config of real HuggingFace Hub model"
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||
"learning_rate": 0.000001,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.0,
|
||||
"lora_target_linear": True,
|
||||
"sequence_len": 1024,
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
"lora_o_kernel": True,
|
||||
}
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def test_attention_patching_integration(minimal_cfg):
|
||||
def test_attention_patching_integration():
|
||||
"""Test attention patching in integration context."""
|
||||
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
|
||||
|
||||
# Store the original implementation
|
||||
original_forward = getattr(LlamaAttention, "forward")
|
||||
|
||||
# Load model
|
||||
_, _ = load_model_and_tokenizer(cfg=minimal_cfg)
|
||||
# Apply patch
|
||||
patch_self_attn_lora(cfg)
|
||||
|
||||
# Get the new forward method
|
||||
patched_forward = LlamaAttention.forward
|
||||
@@ -404,10 +376,38 @@ def test_model_architecture(model_config):
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def test_kernel_training_integration(minimal_cfg):
|
||||
def test_kernel_training_integration():
|
||||
"""Test model loading with kernel patches enabled."""
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
|
||||
# Create minimal config
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||
"learning_rate": 0.000001,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.0,
|
||||
"lora_target_linear": True,
|
||||
"sequence_len": 1024,
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
"lora_o_kernel": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Load model
|
||||
model, _ = load_model_and_tokenizer(cfg=minimal_cfg)
|
||||
model, _ = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
# Verify correct activation function
|
||||
layer = model.model.model.layers[0]
|
||||
|
||||
@@ -125,6 +125,12 @@ def fixture_llama3_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
|
||||
def fixture_smollm2_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
||||
def fixture_mistralv03_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
|
||||
61
tests/prompt_strategies/test_dpo_chatml.py
Normal file
61
tests/prompt_strategies/test_dpo_chatml.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Tests for loading DPO preference datasets with chatml formatting
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="minimal_dpo_cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||
"rl": "dpo",
|
||||
"learning_rate": 0.000001,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"sequence_len": 2048,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestDPOChatml:
|
||||
"""
|
||||
Test loading DPO preference datasets with chatml formatting
|
||||
"""
|
||||
|
||||
def test_default(self, minimal_dpo_cfg):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"datasets": [
|
||||
{
|
||||
"path": "argilla/distilabel-intel-orca-dpo-pairs",
|
||||
"type": "chatml",
|
||||
"split": "train[:1%]",
|
||||
}
|
||||
]
|
||||
}
|
||||
| minimal_dpo_cfg
|
||||
)
|
||||
|
||||
# test that dpo.load works
|
||||
load_dpo("chatml", cfg)
|
||||
# now actually load the datasets with the strategy
|
||||
train_ds, _ = load_prepare_preference_datasets(cfg)
|
||||
assert train_ds[0]["prompt"].startswith("<|im_start|>")
|
||||
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
|
||||
assert "chosen" in train_ds[0]
|
||||
assert "rejected" in train_ds[0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.completion import load
|
||||
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.data.utils import drop_long_seq_in_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
@@ -18,11 +19,6 @@ def fixture_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="max_seq_length")
|
||||
def fixture_max_seq_length():
|
||||
return 4096
|
||||
|
||||
|
||||
class TestBatchedSamplerPacking:
|
||||
"""
|
||||
Test class for packing streaming dataset sequences
|
||||
@@ -37,6 +33,7 @@ class TestBatchedSamplerPacking:
|
||||
(2, 2),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("max_seq_length", [4096, 512])
|
||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
@@ -62,6 +59,9 @@ class TestBatchedSamplerPacking:
|
||||
dataset,
|
||||
)
|
||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||
|
||||
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
|
||||
|
||||
lengths = get_dataset_lengths(train_dataset)
|
||||
batch_sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
@@ -90,7 +90,7 @@ class TestBatchedSamplerPacking:
|
||||
batch_idxs.extend(pack)
|
||||
|
||||
for batch in loader:
|
||||
assert len(batch["input_ids"]) <= batch_size * max_seq_length
|
||||
assert batch["input_ids"].numel() <= batch_size * max_seq_length
|
||||
assert batch["input_ids"].shape[1] == max_seq_length
|
||||
|
||||
original_idxs = set(range(len(train_dataset)))
|
||||
|
||||
Reference in New Issue
Block a user