Compare commits

..

16 Commits

Author SHA1 Message Date
Dan Saunders
5306c6acbb fix 2025-04-24 22:35:18 +00:00
Dan Saunders
4ae8df16a9 adding all batch ring-flash-attn methods via single adapter 2025-04-24 22:35:18 +00:00
Dan Saunders
74e7cfd28f update 2025-04-24 22:35:18 +00:00
Dan Saunders
2bb5c1fe7e batch api HF adapter for ring-flash-attn; cleanup and improvements 2025-04-24 22:35:18 +00:00
Dan Saunders
3f1873cc62 pytest 2025-04-24 22:31:24 +00:00
Dan Saunders
072df89e0e add gather post hook, simplify, fixes 2025-04-24 22:31:24 +00:00
Dan Saunders
cb7c3ee847 tweak codecov yaml 2025-04-24 22:31:24 +00:00
Dan Saunders
d92ac7a41d reorg 2025-04-24 22:31:24 +00:00
Dan Saunders
5816433121 nit 2025-04-24 22:31:24 +00:00
Dan Saunders
e5a4e21497 simplifying 2025-04-24 22:31:24 +00:00
Dan Saunders
65ae78009c simplifying 2025-04-24 22:31:24 +00:00
Dan Saunders
7e5168ad74 accommodate both training context managers 2025-04-24 22:31:24 +00:00
Dan Saunders
cd393fecc3 further simplifying 2025-04-24 22:31:24 +00:00
Dan Saunders
bac5568bda update 2025-04-24 22:31:24 +00:00
Dan Saunders
69aeae80ed updates 2025-04-24 22:31:24 +00:00
Dan Saunders
cafda804ec ctx manager for SP 2025-04-24 22:31:24 +00:00
33 changed files with 64 additions and 826 deletions

View File

@@ -24,7 +24,7 @@ jobs:
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras: vllm
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"

View File

@@ -8,7 +8,6 @@ on:
- 'setup.py' - 'setup.py'
- 'pyproject.toml' - 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml' - '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
workflow_dispatch: workflow_dispatch:
schedule: schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -43,7 +42,7 @@ jobs:
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras: vllm
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
- cuda: 126 - cuda: 126

View File

@@ -258,12 +258,6 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: llmcompressor
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
@@ -275,7 +269,7 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras: vllm
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"

View File

@@ -52,4 +52,4 @@ pytest -v --durations=10 \
--cov-append \ --cov-append \
--cov-report=xml:e2e-coverage.xml --cov-report=xml:e2e-coverage.xml
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}

View File

@@ -20,4 +20,4 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov-report=xml:multigpu-coverage.xml --cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov # Upload coverage to Codecov
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true codecov upload-process -t $CODECOV_TOKEN -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}

View File

@@ -49,8 +49,7 @@ sections = [
("Knowledge Distillation (KD)", "kd"), ("Knowledge Distillation (KD)", "kd"),
("Liger Kernels", "liger"), ("Liger Kernels", "liger"),
("Language Model Evaluation Harness (LM Eval)", "lm_eval"), ("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
("Spectrum", "spectrum"), ("Spectrum", "spectrum")
("LLMCompressor", "llm_compressor")
] ]
for section_name, folder_name in sections: for section_name, folder_name in sections:

View File

@@ -28,8 +28,6 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples: Tags examples:
- `main-base-py3.11-cu128-2.7.0`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu124-2.6.0` - `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1` - `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1` - `main-base-py3.11-cu124-2.4.1`
@@ -52,7 +50,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
# on push to main # on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version} main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4) # latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
main-latest main-latest
# nightly build # nightly build
@@ -70,7 +68,6 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples: Tags examples:
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu124-2.6.0` - `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1` - `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1` - `main-py3.11-cu124-2.4.1`

View File

@@ -1,77 +0,0 @@
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true

View File

@@ -10,6 +10,7 @@ plugins:
liger_glu_activation: true liger_glu_activation: true
liger_rms_norm: true liger_rms_norm: true
liger_layer_norm: true liger_layer_norm: true
cut_cross_entropy: true
llama4_linearized_experts: true # needed with custom linearized experts model llama4_linearized_experts: true # needed with custom linearized experts model
load_in_4bit: true load_in_4bit: true

View File

@@ -11,13 +11,13 @@ liger-kernel==0.5.8
packaging==23.2 packaging==23.2
peft==0.15.2 peft==0.15.1
transformers==4.51.3 transformers==4.51.3
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.6.0 accelerate==1.6.0
datasets==3.5.0 datasets==3.5.0
deepspeed>=0.15.4 deepspeed>=0.15.4
trl==0.17.0 trl==0.16.1
hf_xet==1.0.0 hf_xet==1.0.0
hqq==0.2.5 hqq==0.2.5

View File

@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 7): if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0 # _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.4"] extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 6): elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append( _install_requires.append(
"xformers==0.0.29.post2" "xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6 ) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.4"] extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 5): elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if patch == 0: if patch == 0:
@@ -149,9 +149,6 @@ extras_require = {
"vllm": [ "vllm": [
"vllm==0.7.2", "vllm==0.7.2",
], ],
"llmcompressor": [
"llmcompressor==0.5.1",
],
} }
install_requires, dependency_links, extras_require_build = parse_requirements( install_requires, dependency_links, extras_require_build = parse_requirements(

View File

@@ -1048,9 +1048,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rpo_alpha is not None: if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None training_args_cls = None
blocklist_args_kwargs = [] blocklist_args_kwargs = []
if self.cfg.rl == "simpo": if self.cfg.rl == "simpo":
@@ -1121,12 +1118,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
**training_args_kwargs, **training_args_kwargs,
) )
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args return training_args
def build(self, total_num_steps): def build(self, total_num_steps):

View File

@@ -135,9 +135,7 @@ class GRPOStrategy:
try: try:
# use importlib to dynamically load the reward function from the module # use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_fqn.split(".")[-1] reward_func_module_name = reward_func_fqn.split(".")[-1]
reward_func_module = importlib.import_module( reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
".".join(reward_func_fqn.split(".")[:-1])
)
reward_func = getattr(reward_func_module, reward_func_module_name) reward_func = getattr(reward_func_module, reward_func_module_name)
if not len(inspect.signature(reward_func).parameters) >= 2: if not len(inspect.signature(reward_func).parameters) >= 2:
raise ValueError( raise ValueError(

View File

@@ -27,6 +27,8 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
```yaml ```yaml
plugins: plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
``` ```
## Supported Models ## Supported Models

View File

@@ -28,7 +28,7 @@ class CutCrossEntropyArgs(BaseModel):
Input args for Cut Cross Entropy. Input args for Cut Cross Entropy.
""" """
cut_cross_entropy: Optional[bool] = True cut_cross_entropy: Optional[bool] = None
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod

View File

@@ -1,108 +0,0 @@
# LLMCompressor Integration
Fine-tune sparsified models in Axolotl using Neural Magic's [LLMCompressor](https://github.com/vllm-project/llm-compressor).
This integration enables fine-tuning of models sparsified using LLMCompressor within the Axolotl training framework. By combining LLMCompressor's model compression capabilities with Axolotl's distributed training pipelines, users can efficiently fine-tune sparse models at scale.
It uses Axolotls plugin system to hook into the fine-tuning flows while maintaining sparsity throughout training.
---
## Requirements
- Axolotl with `llmcompressor` extras:
```bash
pip install "axolotl[llmcompressor]"
```
- Requires `llmcompressor >= 0.5.1`
This will install all necessary dependencies to fine-tune sparsified models using the integration.
---
## Usage
To enable sparse fine-tuning with this integration, include the plugin in your Axolotl config:
```yaml
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true
# ... (other training arguments)
```
This plugin **does not apply pruning or sparsification itself** — it is intended for **fine-tuning models that have already been sparsified**.
Pre-sparsified checkpoints can be:
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
- Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
- Any custom LLM with compatible sparsity patterns that you've created yourself
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
### Storage Optimization with save_compressed
Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
- Reduces disk space usage by approximately 40%
- Maintains compatibility with vLLM for accelerated inference
- Maintains compatibility with llmcompressor for further optimization (example: quantization)
This option is highly recommended when working with sparse models to maximize the benefits of model compression.
### Example Config
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
---
## Inference with vLLM
After fine-tuning your sparse model, you can leverage vLLM for efficient inference.
You can also use LLMCompressor to apply additional quantization to your fine-tuned
sparse model before inference for even greater performance benefits.:
```python
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM("path/to/your/sparse/model")
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
For more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).
## Learn More
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
[https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)

View File

@@ -1,5 +0,0 @@
"""Integration entry point for the LLMCompressor plugin."""
from .plugin import LLMCompressorPlugin
__all__ = ["LLMCompressorPlugin"]

View File

@@ -1,40 +0,0 @@
"""
LLMCompressor and Sparse Finetuning config models.
"""
from typing import Any
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class CompressionArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[
Any,
Field(
description="The recipe containing the compression algorithms and hyperparameters to apply."
),
]
save_compressed: Annotated[
bool,
Field(
default=False,
description="Whether to save the compressed model after training.",
),
]
class LLMCompressorArgs(BaseModel):
"""LLMCompressor configuration BaseModel."""
llmcompressor: Annotated[
CompressionArgs,
Field(
description="Arguments enabling compression pathways through the LLM Compressor plugins"
),
]

View File

@@ -1,171 +0,0 @@
"""
Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
by maintaining masks for zero weights during training.
"""
import logging
from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
from llmcompressor import active_session, create_session
from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe
from torch.nn import Module
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.integrations.base import BasePlugin
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
class LLMCompressorCallbackHandler(TrainerCallback):
"""
Trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps,
ensuring zero-weight updates are canceled out.
"""
def __init__(self, trainer: Trainer, recipe: Any):
"""
Initialize the Sparse Finetuning callback handler.
Args:
trainer (Trainer): Huggingface Trainer instance.
recipe (Recipe | dict): Sparse finetuning recipe to apply.
"""
super().__init__()
self.trainer = trainer
self.recipe = (
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
)
self.original_compute_loss = trainer.compute_loss
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
create_session()
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of training. Initializes the compression session.
Args:
args (TrainingArguments): Training arguments.
state (TrainerState): Trainer state.
control (TrainerControl): Trainer control.
"""
super().on_train_begin(args, state, control, **kwargs)
self.trainer.accelerator.wait_for_everyone()
active_session().initialize(
model=self.trainer.model,
optimizer=self.trainer.optimizer,
start=state.epoch,
recipe=self.recipe,
)
self.trainer.accelerator.wait_for_everyone()
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of a training step. Triggers batch_start callback.
"""
super().on_step_begin(args, state, control, **kwargs)
session_callbacks.batch_start()
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of a training step. Triggers optimizer and batch_end callbacks.
"""
super().on_step_end(args, state, control, **kwargs)
session_callbacks.optim_pre_step()
session_callbacks.optim_post_step()
session_callbacks.batch_end()
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of training. Finalizes the compression session.
"""
super().on_train_end(args, state, control, **kwargs)
active_session().finalize()
self.trainer.compute_loss_func = self.original_compute_loss
class LLMCompressorPlugin(BasePlugin):
"""
Sparse Finetuning plugin for Axolotl integration.
"""
def get_input_args(self) -> str:
"""
Returns the path to the plugin's argument definition.
Returns:
str: Dotted path to the LLMCompressorArgs class.
"""
return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds Sparse Finetuning callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
LOG.info("Adding Sparse Finetuning callback to the trainer")
callback = LLMCompressorCallbackHandler(
trainer=trainer,
recipe=cfg.llmcompressor.recipe,
)
return [callback]
def compute_loss_wrapper(
compute_loss_func: Callable[Concatenate[Module, P], R],
) -> Callable[Concatenate[Module, P], R]:
"""
Wraps the loss computation function to trigger the loss_calculated callback.
Args:
compute_loss_func (Callable): Original loss computation function.
Returns:
Callable: Wrapped function that also invokes the loss_calculated callback.
"""
@wraps(compute_loss_func)
def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R:
loss = compute_loss_func(model, *args, **kwargs)
if active_session().lifecycle.initialized_ and model.training:
session_callbacks.loss_calculated(loss=loss)
return loss
return compute_and_notify

View File

@@ -1,40 +0,0 @@
"""Utilities for llmcompressor integration with axolotl."""
from typing import Union
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
)
from transformers import PreTrainedModel, Trainer
def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
Synchronize processes, apply compression hooks, and save the model.
Args:
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
# Only the main process writes the files
if not trainer.accelerator.is_main_process:
return
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -76,7 +76,8 @@ def register_ring_attn(
LOG.info( LOG.info(
"Enabling ring attention sequence parallelism: " "Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs" f"each sequence will be processed across {sequence_parallel_degree} GPUs "
f"using the {ring_attn_func.value} ring-flash-attn implementation"
) )
rank = dist.get_rank() rank = dist.get_rank()

View File

@@ -295,23 +295,8 @@ def save_trained_model(
trainer.model.save_pretrained( trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization cfg.output_dir, safe_serialization=safe_serialization
) )
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
save_compressed_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
def create_model_card(cfg: DictDefault, trainer: Trainer): def create_model_card(cfg: DictDefault, trainer: Trainer):
""" """

View File

@@ -134,9 +134,10 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
"csv", data_files=f.name, split="train", streaming=True "csv", data_files=f.name, split="train", streaming=True
) )
else: else:
iter_ds = load_dataset( if is_local_main_process():
path, streaming=True, split=split, name=name, data_files=data_files iter_ds = load_dataset(
) path, streaming=True, split=split, name=name, data_files=data_files
)
if skip: if skip:
LOG.info(f"Skipping {skip} samples from the dataset") LOG.info(f"Skipping {skip} samples from the dataset")

View File

@@ -1,7 +1,5 @@
"""custom checkpointing utils""" """custom checkpointing utils"""
from functools import partial
from axolotl.utils.gradient_checkpointing.unsloth import ( from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer, Unsloth_Offloaded_Gradient_Checkpointer,
) )
@@ -11,10 +9,6 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply( return Unsloth_Offloaded_Gradient_Checkpointer.apply(
( decoder_layer.__self__,
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
*args, *args,
) )

View File

@@ -139,22 +139,6 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
hasattr(model_config, "quantization_config") hasattr(model_config, "quantization_config")
and model_config.quantization_config and model_config.quantization_config
) )
# Detect compressed-tensors config
is_compressed_tensors_config = (
quant_config_exists
and model_config.quantization_config.get("quant_method") == "compressed-tensors"
)
if is_compressed_tensors_config:
if model_config.quantization_config.get("config_groups"):
LOG.warning(
"Found `config_groups` in a compressed-tensors config. "
"QAT integration with llmcompressor is not tested."
)
# Skip further quant checks for compressed-tensors
return
quant_config_method_is_gptq = ( quant_config_method_is_gptq = (
quant_config_exists quant_config_exists
and "quant_method" in model_config.quantization_config and "quant_method" in model_config.quantization_config

View File

@@ -18,6 +18,7 @@ from pydantic import (
) )
from transformers.utils.import_utils import is_torch_npu_available from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.distributed import is_main_process
from axolotl.utils.schemas.datasets import ( from axolotl.utils.schemas.datasets import (
DatasetConfig, DatasetConfig,
DPODataset, DPODataset,
@@ -718,9 +719,10 @@ class AxolotlInputConfig(
and data.get("eval_sample_packing") is None and data.get("eval_sample_packing") is None
and not data.get("eval_table_size") and not data.get("eval_table_size")
): ):
LOG.info( if is_main_process():
"explicitly setting `eval_sample_packing` to match `sample_packing`" LOG.info(
) "explicitly setting `eval_sample_packing` to match `sample_packing`"
)
data["eval_sample_packing"] = True data["eval_sample_packing"] = True
if ( if (
@@ -1177,14 +1179,15 @@ class AxolotlInputConfig(
# TODO: monkeypatch / callback to average losses correctly across SP ranks # TODO: monkeypatch / callback to average losses correctly across SP ranks
# / fix gradient scaling across SP ranks. Losses, grads should be scaled # / fix gradient scaling across SP ranks. Losses, grads should be scaled
# according to the proportion of non-padding tokens per rank. # according to the proportion of non-padding tokens per rank.
LOG.warning( if is_main_process():
"Sequence parallelism (SP) is enabled with " LOG.warning(
f"sequence_parallel_degree={self.sequence_parallel_degree}. " "Sequence parallelism (SP) is enabled with "
"Please note that logged losses may differ slightly to the non-SP " f"sequence_parallel_degree={self.sequence_parallel_degree}. "
"losses due to transformers Trainer implementation details. " "Please note that logged losses may differ slightly to the non-SP "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " "losses due to transformers Trainer implementation details. "
"for more details." "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
) "for more details."
)
return self return self

View File

@@ -528,13 +528,6 @@ def setup_torch_compile_env(cfg):
def setup_deepspeed_env(cfg, stage=None): def setup_deepspeed_env(cfg, stage=None):
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
from axolotl.utils.distributed import distributed_state
if distributed_state and distributed_state.initialized:
raise RuntimeError(
"Distributed State already initialized before Deepspeed setup"
)
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage: if stage:

View File

@@ -1,106 +0,0 @@
"""
E2E smoke tests for LLMCompressorPlugin integration
"""
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import (
check_model_output_exists,
require_llmcompressor,
require_torch_2_4_1,
)
MODELS = [
"nm-testing/llama2.c-stories42M-pruned2.4-compressed",
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
]
@pytest.mark.parametrize(
"base_model", MODELS, ids=["no-checkpoint-recipe", "with-checkpoint-recipe"]
)
@pytest.mark.parametrize(
"save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"]
)
@require_llmcompressor
class TestLLMCompressorIntegration:
"""
e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin
"""
@require_torch_2_4_1
def test_llmcompressor_plugin(
self, temp_dir, base_model: str, save_compressed: bool
):
# core cfg
cfg = DictDefault(
{
"base_model": base_model,
"plugins": ["axolotl.integrations.llm_compressor.LLMCompressorPlugin"],
"sequence_len": 1024,
"val_set_size": 0.05,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 1e-5,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"llmcompressor": {
"recipe": {
"finetuning_stage": {
"finetuning_modifiers": {
"ConstantPruningModifier": {
"targets": [
"re:.*q_proj.weight",
"re:.*k_proj.weight",
"re:.*v_proj.weight",
"re:.*o_proj.weight",
"re:.*gate_proj.weight",
"re:.*up_proj.weight",
"re:.*down_proj.weight",
],
"start": 0,
},
},
},
},
"save_compressed": save_compressed,
},
}
)
prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
_check_llmcompressor_model_outputs(temp_dir, save_compressed)
def _check_llmcompressor_model_outputs(temp_dir, save_compressed):
if save_compressed:
assert (Path(temp_dir) / "recipe.yaml").exists()
from compressed_tensors import ModelCompressor
from compressed_tensors.config import Sparse24BitMaskConfig
compressor = ModelCompressor.from_pretrained(temp_dir)
assert compressor is not None
assert isinstance(compressor.sparsity_config, Sparse24BitMaskConfig)

View File

@@ -4,14 +4,11 @@ GRPO test suite
import os import os
import random import random
import shutil
import subprocess # nosec B404 import subprocess # nosec B404
import sys import sys
import tempfile
import time import time
from pathlib import Path from pathlib import Path
import psutil
import pytest import pytest
import requests import requests
import yaml import yaml
@@ -24,8 +21,8 @@ from tests.e2e.utils import require_vllm
def start_vllm( def start_vllm(
model: str, env: dict, wait: int | None = None, quiet=False, **kwargs model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
) -> subprocess.Popen: ) -> int:
""" """
helper function to start the VLLM server in the background, mostly for testing purposes helper function to start the VLLM server in the background, mostly for testing purposes
""" """
@@ -49,41 +46,10 @@ def start_vllm(
# print out the command to be executed # print out the command to be executed
print(" ".join(cmd)) print(" ".join(cmd))
vllm_logging_json = Path(tempfile.mkdtemp()) / "vllm_logging.json"
with open(vllm_logging_json, "w", encoding="utf-8") as temp_file:
temp_file.write(
"""{
"formatters": {
"json": {
"class": "pythonjsonlogger.jsonlogger.JsonFormatter"
}
},
"handlers": {
"file": {
"class": "logging.FileHandler",
"formatter": "json",
"level": "DEBUG",
"filename": "/tmp/vllm.log",
"mode": "a"
}
},
"loggers": {
"vllm": {
"handlers": ["file"],
"level": "DEBUG",
"propagate": false
}
},
"version": 1
}"""
)
cmd_env = env.copy()
cmd_env.update({"VLLM_LOGGING_CONFIG_PATH": vllm_logging_json})
# start `trl vllm-serve` command in the background and capture the process id # start `trl vllm-serve` command in the background and capture the process id
process = subprocess.Popen( # pylint: disable=consider-using-with process = subprocess.Popen( # pylint: disable=consider-using-with
cmd, cmd,
env=cmd_env, env=env,
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE, stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE, stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
) # nosec B603 ) # nosec B603
@@ -92,51 +58,32 @@ def start_vllm(
print(f"VLLM server process started (PID: {process.pid})") print(f"VLLM server process started (PID: {process.pid})")
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds # wait until the http server is ready, even if it 404s, but timeout after 60 seconds
period_seconds = 5
started = False started = False
if wait and host and port: if wait and host and port:
for i in range(0, int(wait), period_seconds): for _ in range(int(wait)):
try: try:
response = requests.get(f"http://{host}:{port}", timeout=1) response = requests.get(f"http://{host}:{port}", timeout=1)
print(f"{i}: VLLM server (status: {response.status_code})")
if int(response.status_code) in [200, 404]: if int(response.status_code) in [200, 404]:
started = True started = True
break break
except requests.exceptions.RequestException as exc: except requests.exceptions.RequestException:
print(f"{i}: VLLM server failed to start: {str(exc)}") pass
# also check if the process.pid is still running # also check if the process.pid is still running
if not process.poll() is None: if not process.poll() is None:
break break
time.sleep(period_seconds) time.sleep(1)
if wait and not started: if wait and not started:
print( print(
f"VLLM server process did not start within {wait} seconds. Please check your server logs." f"VLLM server process did not start within {wait} seconds. Please check your server logs."
) )
recursive_kill(process) process.kill()
with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file:
print(log_file.read())
shutil.rmtree("/tmp/vllm.log")
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.") raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
# return the process # return the process id
return process return process.pid
def recursive_kill(process: subprocess.Popen):
"""
Recursively kill a process and its children
"""
process = psutil.Process(process.pid)
for child in psutil.Process(process.pid).children(recursive=True):
child.terminate()
child.kill()
os.kill(child.pid, 9)
process.terminate()
process.kill()
os.kill(process.pid, 9)
class TestGRPO: class TestGRPO:
@@ -227,17 +174,16 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy() current_env = os.environ.copy()
env = { env = {
"NCCL_P2P_LEVEL": "NVL", "NCCL_P2P_LEVEL": "LOC",
**current_env, **current_env,
"CUDA_VISIBLE_DEVICES": "1", "CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1", "VLLM_USE_V1": "0",
# "VLLM_USE_V1": "0",
} }
vllm_process = start_vllm( vllm_process_id = start_vllm(
cfg.base_model, cfg.base_model,
env=env, env=env,
quiet=True, quiet=True,
wait=300, wait=120,
gpu_memory_utilization=0.15, gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len, max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching, enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -256,14 +202,10 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"--main-process-port", "--main-process-port",
f"{get_torch_dist_unique_port()}", f"{get_torch_dist_unique_port()}",
], ],
env={ env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
"NCCL_P2P_LEVEL": "NVL",
"NCCL_DEBUG": "INFO",
**current_env,
},
) )
finally: finally:
recursive_kill(vllm_process) os.kill(vllm_process_id, 9)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_gpus", "num_gpus",
@@ -320,17 +262,16 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy() current_env = os.environ.copy()
env = { env = {
"NCCL_P2P_LEVEL": "NVL", # nccl can be brittle, assume P2P isn't reliable "NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env, **current_env,
"CUDA_VISIBLE_DEVICES": "1", "CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1", "VLLM_USE_V1": "0",
# "VLLM_USE_V1": "0",
} }
vllm_process = start_vllm( vllm_process_id = start_vllm(
cfg.base_model, cfg.base_model,
env=env, env=env,
quiet=True, quiet=True,
wait=300, wait=120,
gpu_memory_utilization=0.15, gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len, max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching, enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -349,11 +290,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"--main-process-port", "--main-process-port",
f"{get_torch_dist_unique_port()}", f"{get_torch_dist_unique_port()}",
], ],
env={ env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
"NCCL_P2P_LEVEL": "NVL",
"NCCL_DEBUG": "INFO",
**current_env,
},
) )
finally: finally:
recursive_kill(vllm_process) os.kill(vllm_process_id, 9)

View File

@@ -1,77 +0,0 @@
"""
E2E tests for activation checkpointing
"""
import pytest
import transformers
from torch.utils.checkpoint import checkpoint
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
@pytest.fixture()
def fix_checkpoint_after_test():
yield
transformers.modeling_utils.checkpoint = checkpoint
class TestActivationCheckpointing:
"""
E2E tests for activation checkpointing
"""
def test_activation_checkpointing_offload(
self,
temp_dir,
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"eos_token": "<|im_end|>",
},
"datasets": [
{
"chat_template": "chatml",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": "offload",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -99,7 +99,6 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -131,6 +131,11 @@ class TestConfigValidation:
# Mock the ring_flash_attn module # Mock the ring_flash_attn module
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock()) monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
# Mock the is_main_process function to return True
monkeypatch.setattr(
"axolotl.utils.schemas.config.is_main_process", lambda: True
)
@pytest.fixture @pytest.fixture
def base_cfg(self): def base_cfg(self):
"""Create a base configuration for testing.""" """Create a base configuration for testing."""

View File

@@ -109,24 +109,6 @@ def require_vllm(test_case):
)(test_case) )(test_case)
def require_llmcompressor(test_case):
"""
Decorator marking a test that requires a llmcompressor to be installed
"""
def is_llmcompressor_installed():
try:
import llmcompressor # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
return unittest.skipUnless(
is_llmcompressor_installed(), "test requires a llmcompressor to be installed"
)(test_case)
def is_hopper(): def is_hopper():
compute_capability = torch.cuda.get_device_capability() compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0) return compute_capability == (9, 0)