Compare commits

..

5 Commits

Author SHA1 Message Date
Dan Saunders
822a8a6931 pylint 2025-02-18 19:59:17 +00:00
Dan Saunders
1a51180637 removing unused function 2025-02-18 19:36:03 +00:00
Dan Saunders
7562aadf89 fix 2025-02-18 19:13:09 +00:00
Dan Saunders
479f5e18dd Small updates 2025-02-18 19:08:27 +00:00
Dan Saunders
945dcc5020 move patching to post-model load to improve applicability 2025-02-18 19:00:12 +00:00
25 changed files with 152 additions and 360 deletions

View File

@@ -4,10 +4,6 @@ on:
pull_request: pull_request:
paths: paths:
- 'tests/e2e/multigpu/*.py' - 'tests/e2e/multigpu/*.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
workflow_dispatch: workflow_dispatch:
schedule: schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

View File

@@ -37,11 +37,15 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f: with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents) f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile( cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile", pathlib.Path(temp_dir) / "Dockerfile",
force_build=True, force_build=True,
gpu="A10G", gpu="A10G",
).env(df_args) )
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
app = App("Axolotl CI/CD", secrets=[]) app = App("Axolotl CI/CD", secrets=[])

View File

@@ -407,10 +407,7 @@ save_total_limit: # Checkpoints saved at a time
max_steps: max_steps:
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time. # bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: # Optional[bool] include_tokens_per_second:
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

View File

@@ -12,7 +12,6 @@ to leverage operator fusion and tensor re-use in order to improve speed and redu
memory usage during the forward and backward passes of these calculations. memory usage during the forward and backward passes of these calculations.
We currently support several common model architectures, including (but not limited to): We currently support several common model architectures, including (but not limited to):
- `llama` - `llama`
- `mistral` - `mistral`
- `qwen2` - `qwen2`

View File

@@ -13,12 +13,12 @@ liger-kernel==0.5.2
packaging==23.2 packaging==23.2
peft==0.14.0 peft==0.14.0
transformers==4.49.0 transformers==4.48.3
tokenizers>=0.21.0 tokenizers>=0.21.0
accelerate==1.3.0 accelerate==1.3.0
datasets==3.2.0 datasets==3.2.0
deepspeed==0.16.1 deepspeed==0.16.1
trl==0.15.1 trl==0.15.0
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer

View File

@@ -123,6 +123,8 @@ class ModalCloud(Cloud):
if env := self.get_env(): if env := self.get_env():
image = image.env(env) image = image.env(env)
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
return image return image
def get_secrets(self): def get_secrets(self):

View File

@@ -59,7 +59,6 @@ from axolotl.core.training_args import (
AxolotlTrainingArguments, AxolotlTrainingArguments,
) )
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType, get_extract_fn
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils import is_comet_available, is_mlflow_available
@@ -747,11 +746,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64 data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.sp_ulysses_degree:
data_collator_kwargs["sp_extract_fn"] = get_extract_fn(
USPRingAttnType.ZIGZAG,
sp_ulysses_degree=self.cfg.sp_ulysses_degree
)
if self.cfg.reward_model: if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len data_collator_kwargs["max_length"] = self.cfg.sequence_len

View File

@@ -78,6 +78,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
if is_peft_model(unwrapped_model): if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter() unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict() state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
# Remove base_model and base_layer prefixes # Remove base_model and base_layer prefixes
state_dict = { state_dict = {
k.removeprefix("base_model.model.") k.removeprefix("base_model.model.")
@@ -104,5 +105,3 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self.llm.llm_engine.model_executor.driver_worker.model_runner.model self.llm.llm_engine.model_executor.driver_worker.model_runner.model
) )
llm_model.load_weights(state_dict.items()) llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()

View File

@@ -206,16 +206,6 @@ class AxolotlTrainingMixins:
}, },
) )
sp_ulysses_degree: Optional[int] = field(
default=None,
metadata={"help": "Ulysses parallelism for hybrid sequence parallel long context attn"},
)
sp_ring_degree: Optional[int] = field(
default=None,
metadata={"help": "Ring attention parallelism for sequence parallel long context attn"},
)
@dataclass @dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):

View File

@@ -1,45 +0,0 @@
from enum import Enum
from functools import partial
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from yunchang import set_seq_parallel_pg, EXTRACT_FUNC_DICT
from axolotl.utils.distributed import get_world_size, get_rank
class USPRingAttnType(Enum):
BASIC = "basic"
ZIGZAG = "zigzag"
STRIPE = "stripe"
def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward
fa_forward = build_usp_fa_forward(ring_impl_type)
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward
def get_extract_fn(ring_impl_type: USPRingAttnType, sp_ulysses_degree: int):
fn = EXTRACT_FUNC_DICT["basic"]
if ring_impl_type.value in EXTRACT_FUNC_DICT.keys():
fn = EXTRACT_FUNC_DICT[ring_impl_type.value]
# map bad key upstream
elif ring_impl_type == USPRingAttnType.STRIPE:
fn = EXTRACT_FUNC_DICT["strip"]
world_size = get_world_size()
rd = world_size // sp_ulysses_degree
return partial(fn, rank=get_rank(), world_size=world_size, rd=rd, ud=sp_ulysses_degree)
def set_usp_parallel_group(sp_ulysses_degree):
"""
setup distributed parallel group for USP attention
make sure this gets called before building any USP attention modules
:param sp_ulysses_degree:
:return:
"""
world_size = get_world_size()
rank = get_rank()
sp_ring_degree = world_size // sp_ulysses_degree
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

View File

@@ -1,36 +0,0 @@
from enum import Enum
from typing import Optional, Tuple, Callable
import torch
from yunchang import LongContextAttention
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
usp_attn = LongContextAttention(ring_impl_type.value)
def flash_attention_forward(
module: torch.nn.Module, # pylint: disable=unused-argument
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor], # pylint: disable=unused-argument
dropout: float = 0.0,
scaling: Optional[float] = None,
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
softcap: Optional[float] = None,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, None]:
attn_output = usp_attn(
query,
key,
value,
dropout_p=dropout,
softmax_scale=scaling,
causal=True,
softcap=softcap,
)
return attn_output, None
return flash_attention_forward

View File

@@ -4,13 +4,12 @@ import importlib
import inspect import inspect
import logging import logging
import types import types
from typing import Type
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from peft import PeftModelForCausalLM from peft import PeftModelForCausalLM
from torch import nn from torch import nn
from transformers import AutoConfig from transformers.modeling_utils import PreTrainedModel
from axolotl.kernels.lora import ( from axolotl.kernels.lora import (
apply_lora_mlp_geglu, apply_lora_mlp_geglu,
@@ -96,82 +95,61 @@ def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tens
return attn_output return attn_output
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
"""
Get the appropriate attention class by inspecting the model config.
Uses dynamic import to support any model architecture that follows
the standard transformers naming convention.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
The appropriate attention class for the model.
Raises:
ValueError: If `base_model` not specified or attention class cannot be imported
ImportError: If the model module or attention class doesn't exist
"""
if "base_model" not in cfg:
raise ValueError("base_model must be specified in config")
# Get model config without loading the model
model_config = AutoConfig.from_pretrained(cfg["base_model"])
model_type = model_config.model_type
# Special case for model_type = "qwen2"
if model_type == "qwen2":
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
return Qwen2Attention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
)
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
return attention_cls
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
# pylint: disable=protected-access # pylint: disable=protected-access
def patch_self_attn_lora(cfg: DictDefault): def patch_self_attn_lora(model: PreTrainedModel):
""" """
Given an `axolotl` config, this method patches the inferred attention class forward Patches the attention classes in a transformer model with optimized LoRA implementations.
pass with optimized LoRA implementations.
It modifies the attention class to use optimized QKV and output projections. The It modifies the attention class to use optimized QKV and output projections. The
original implementation is preserved and can be restored if needed. original implementation is preserved and can be restored if needed.
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. model: A HuggingFace transformers model.
Raises: Raises:
AssertionError: If the required code blocks are not found in the attention AssertionError: If the required code blocks are not found in the attention
implementation. implementation.
""" """
attention_cls = get_attention_cls_from_config(cfg) # Find all attention modules in the model
attention_modules = [
module
for module in model.modules()
if "attention" in module.__class__.__name__.lower()
and hasattr(module, "forward")
]
# Check if already patched if not attention_modules:
if hasattr(attention_cls, "_original_forward"): LOG.warning("No attention modules found in model")
LOG.info(f"{attention_cls.__name__} already patched")
return return
attention_classes = {type(module) for module in attention_modules}
LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}")
for attention_cls in attention_classes:
# Skip if already patched
if hasattr(attention_cls, "_original_forward"):
LOG.info(f"{attention_cls.__name__} already patched")
continue
# Get and store original forward implementation
self_attn_forward = inspect.getsource(attention_cls.forward) self_attn_forward = inspect.getsource(attention_cls.forward)
attention_cls._original_forward = self_attn_forward attention_cls._original_forward = self_attn_forward
# Remove indentation
self_attn_forward, _ = detab_code(self_attn_forward) self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found" # Verify required code blocks exist
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found" assert (
ORIGINAL_QKV_CODE in self_attn_forward
), f"Original QKV code not found in {attention_cls.__name__}"
assert (
ORIGINAL_O_CODE in self_attn_forward
), f"Original O code not found in {attention_cls.__name__}"
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE) # Replace code blocks
self_attn_forward = self_attn_forward.replace(
ORIGINAL_QKV_CODE, PATCHED_QKV_CODE
)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace( self_attn_forward = self_attn_forward.replace(
"def forward(", "def forward(",
@@ -179,7 +157,7 @@ def patch_self_attn_lora(cfg: DictDefault):
1, 1,
) )
# Load necessary imports # Import necessary symbols from the attention module
module_name = attention_cls.__module__ module_name = attention_cls.__module__
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
@@ -188,10 +166,13 @@ def patch_self_attn_lora(cfg: DictDefault):
if item in self_attn_forward: if item in self_attn_forward:
items_to_import.append(item) items_to_import.append(item)
if items_to_import:
exec( # pylint: disable=exec-used # nosec B102 exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})", f"from {module_name} import ({', '.join(items_to_import)})",
globals(), globals(),
) )
# Execute the new implementation
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}") LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")

View File

@@ -127,8 +127,6 @@ class ReLoRACallback(TrainerCallback):
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
**_kwargs, **_kwargs,
): ):
if not optimizer:
optimizer = state.optimizer
if state.global_step > 0 and state.global_step % self.relora_steps == 0: if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join( checkpoint_folder = os.path.join(
args.output_dir, args.output_dir,

View File

@@ -272,7 +272,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
dict(zip(feature_names, row)) dict(zip(feature_names, row))
) )
for key, val in tokenized_prompt.items(): for key, val in tokenized_prompt.items():
res[key].append(val) for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
# If there are no examples left, return an empty dictionary # If there are no examples left, return an empty dictionary
if not res: if not res:

View File

@@ -3,7 +3,7 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Union, Callable from typing import Any, Optional, Union
import numpy as np import numpy as np
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -53,7 +53,6 @@ class DataCollatorForSeq2Seq:
label_pad_token_id: int = -100 label_pad_token_id: int = -100
position_pad_token_id: int = 0 position_pad_token_id: int = 0
return_tensors: str = "pt" return_tensors: str = "pt"
sp_extract_fn: Optional[Callable] = None
def __call__(self, features, return_tensors=None): def __call__(self, features, return_tensors=None):
labels = None labels = None
@@ -122,10 +121,6 @@ class DataCollatorForSeq2Seq:
return features return features
def seq_parallel_split(self, features):
if self.sp_extract_fn:
pass
return features
@dataclass @dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

View File

@@ -342,7 +342,6 @@ class LoraConfig(BaseModel):
peft_use_dora: Optional[bool] = None peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None peft_layer_replication: Optional[List[Tuple[int, int]]] = None
peft_init_lora_weights: Optional[Union[bool, str]] = None
qlora_sharded_model_loading: Optional[bool] = Field( qlora_sharded_model_loading: Optional[bool] = Field(
default=False, default=False,
@@ -832,8 +831,6 @@ class AxolotlInputConfig(
eager_attention: Optional[bool] = None eager_attention: Optional[bool] = None
sp_ulysses_degree: Optional[int] = None
unsloth_cross_entropy_loss: Optional[bool] = None unsloth_cross_entropy_loss: Optional[bool] = None
unsloth_lora_mlp: Optional[bool] = None unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None unsloth_lora_qkv: Optional[bool] = None

View File

@@ -172,11 +172,10 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
) )
try: try:
ds_lengths = get_dataset_lengths(dataset, from_arrow=True) min_input_len = np.min(get_dataset_lengths(dataset))
min_input_len = np.min(ds_lengths) LOG.debug(f"min_input_len: {min_input_len}")
LOG.info(f"min_input_len: {min_input_len}") max_input_len = np.max(get_dataset_lengths(dataset))
max_input_len = np.max(ds_lengths) LOG.debug(f"max_input_len: {max_input_len}")
LOG.info(f"max_input_len: {max_input_len}")
except AttributeError: except AttributeError:
pass pass

View File

@@ -86,12 +86,6 @@ def get_world_size():
return int(os.getenv("WORLD_SIZE", "1")) return int(os.getenv("WORLD_SIZE", "1"))
def get_rank():
if not is_distributed():
return 0
return dist.get_rank()
@contextmanager @contextmanager
def zero_only(): def zero_only():
""" """

View File

@@ -439,11 +439,6 @@ class ModelLoader:
patch_mistral_cross_entropy() patch_mistral_cross_entropy()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)
def patch_attention(self) -> None: def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"): if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention: if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -1028,6 +1023,12 @@ class ModelLoader:
integrate_rope_embeddings() integrate_rope_embeddings()
def apply_lora_patch(self) -> None: def apply_lora_patch(self) -> None:
"""Applies patching relevant to LoRA Triton kernels if enabled."""
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.model)
if ( if (
self.cfg.lora_mlp_kernel self.cfg.lora_mlp_kernel
or self.cfg.lora_qkv_kernel or self.cfg.lora_qkv_kernel
@@ -1181,6 +1182,7 @@ class ModelLoader:
if self.cfg.adapter is not None: if self.cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", self.model.device) log_gpu_memory_usage(LOG, "after adapters", self.model.device)
# TODO: Deprecate this.
self.apply_unsloth_lora_patch() self.apply_unsloth_lora_patch()
self.apply_lora_patch() self.apply_lora_patch()
@@ -1201,9 +1203,7 @@ def load_model(
reference_model: bool = False, reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
""" """Load a model for a given configuration and tokenizer."""
Load a model for a given configuration and tokenizer.
"""
loader = ModelLoader( loader = ModelLoader(
cfg, cfg,
tokenizer, tokenizer,
@@ -1321,8 +1321,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
if loftq_bits: if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits) lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq" lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora: if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.") LOG.info("Initializing LoRA weights using dora. This might take longer.")

View File

@@ -4,17 +4,13 @@ helper util to calculate dataset lengths
import numpy as np import numpy as np
def get_dataset_lengths(dataset, from_arrow=False): def get_dataset_lengths(dataset):
if "length" in dataset.column_names: if "length" in dataset.data.column_names:
lengths = np.array(dataset["length"]) lengths = np.array(dataset.data.column("length"))
elif "position_ids" in dataset.column_names: elif "position_ids" in dataset.data.column_names:
position_ids = dataset["position_ids"] position_ids = dataset.data.column("position_ids")
lengths = np.array([x[-1] + 1 for x in position_ids]) lengths = np.array([x[-1] + 1 for x in position_ids])
else: else:
if from_arrow:
input_ids = dataset.data.column("input_ids") input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
return lengths return lengths

View File

@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)", desc="Add position_id column (PoSE)",
) )
elif cfg.sample_packing or cfg.sp_ulysses_degree: elif cfg.sample_packing:
drop_long_kwargs = {} drop_long_kwargs = {}
if filter_map_kwargs: if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)" drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"

View File

@@ -9,16 +9,14 @@ from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.kernels.lora import ( from axolotl.kernels.lora import (
apply_lora_mlp_geglu, apply_lora_mlp_geglu,
apply_lora_mlp_swiglu, apply_lora_mlp_swiglu,
apply_lora_o, apply_lora_o,
apply_lora_qkv, apply_lora_qkv,
) )
from axolotl.monkeypatch.lora_kernels import ( from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
apply_lora_kernel_patches,
patch_self_attn_lora,
)
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [ MODEL_CONFIGS = [
@@ -65,15 +63,45 @@ def small_llama_model():
return LlamaForCausalLM(LlamaConfig(**config)) return LlamaForCausalLM(LlamaConfig(**config))
def test_attention_patching_integration(): # pylint: disable=duplicate-code
"""Test attention patching in integration context.""" @pytest.fixture
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"} def minimal_cfg():
"Config of real HuggingFace Hub model"
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
return cfg
def test_attention_patching_integration(minimal_cfg):
"""Test attention patching in integration context."""
# Store the original implementation # Store the original implementation
original_forward = getattr(LlamaAttention, "forward") original_forward = getattr(LlamaAttention, "forward")
# Apply patch # Load model
patch_self_attn_lora(cfg) _, _ = load_model_and_tokenizer(cfg=minimal_cfg)
# Get the new forward method # Get the new forward method
patched_forward = LlamaAttention.forward patched_forward = LlamaAttention.forward
@@ -376,38 +404,10 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
def test_kernel_training_integration(): def test_kernel_training_integration(minimal_cfg):
"""Test model loading with kernel patches enabled.""" """Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
# Create minimal config
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# Load model # Load model
model, _ = load_model_and_tokenizer(cfg=cfg) model, _ = load_model_and_tokenizer(cfg=minimal_cfg)
# Verify correct activation function # Verify correct activation function
layer = model.model.model.layers[0] layer = model.model.model.layers[0]

View File

@@ -125,12 +125,6 @@ def fixture_llama3_tokenizer():
return tokenizer return tokenizer
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
def fixture_smollm2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
return tokenizer
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True) @pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
def fixture_mistralv03_tokenizer(): def fixture_mistralv03_tokenizer():
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(

View File

@@ -1,61 +0,0 @@
"""
Tests for loading DPO preference datasets with chatml formatting
"""
import unittest
import pytest
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_dpo_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"rl": "dpo",
"learning_rate": 0.000001,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"sequence_len": 2048,
}
)
class TestDPOChatml:
"""
Test loading DPO preference datasets with chatml formatting
"""
def test_default(self, minimal_dpo_cfg):
cfg = DictDefault(
{
"datasets": [
{
"path": "argilla/distilabel-intel-orca-dpo-pairs",
"type": "chatml",
"split": "train[:1%]",
}
]
}
| minimal_dpo_cfg
)
# test that dpo.load works
load_dpo("chatml", cfg)
# now actually load the datasets with the strategy
train_ds, _ = load_prepare_preference_datasets(cfg)
assert train_ds[0]["prompt"].startswith("<|im_start|>")
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
assert "chosen" in train_ds[0]
assert "rejected" in train_ds[0]
if __name__ == "__main__":
unittest.main()

View File

@@ -7,7 +7,6 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -19,6 +18,11 @@ def fixture_tokenizer():
return tokenizer return tokenizer
@pytest.fixture(name="max_seq_length")
def fixture_max_seq_length():
return 4096
class TestBatchedSamplerPacking: class TestBatchedSamplerPacking:
""" """
Test class for packing streaming dataset sequences Test class for packing streaming dataset sequences
@@ -33,7 +37,6 @@ class TestBatchedSamplerPacking:
(2, 2), (2, 2),
], ],
) )
@pytest.mark.parametrize("max_seq_length", [4096, 512])
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length): def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
@@ -59,9 +62,6 @@ class TestBatchedSamplerPacking:
dataset, dataset,
) )
train_dataset = concatenate_datasets([dataset_wrapper]) train_dataset = concatenate_datasets([dataset_wrapper])
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
lengths = get_dataset_lengths(train_dataset) lengths = get_dataset_lengths(train_dataset)
batch_sampler = MultipackBatchSampler( batch_sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset), sampler=RandomSampler(train_dataset),
@@ -90,7 +90,7 @@ class TestBatchedSamplerPacking:
batch_idxs.extend(pack) batch_idxs.extend(pack)
for batch in loader: for batch in loader:
assert batch["input_ids"].numel() <= batch_size * max_seq_length assert len(batch["input_ids"]) <= batch_size * max_seq_length
assert batch["input_ids"].shape[1] == max_seq_length assert batch["input_ids"].shape[1] == max_seq_length
original_idxs = set(range(len(train_dataset))) original_idxs = set(range(len(train_dataset)))