Compare commits
13 Commits
sp-rl
...
lora-kerne
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
700409be6f | ||
|
|
64d8035f50 | ||
|
|
5249e98058 | ||
|
|
3877c5c69d | ||
|
|
adb593abac | ||
|
|
a0117c9bce | ||
|
|
e6cfb093d2 | ||
|
|
7abc71dc0b | ||
|
|
45bf634d17 | ||
|
|
80ba4b69f1 | ||
|
|
0bfa180f7d | ||
|
|
9e22c4ca6a | ||
|
|
990b5896bc |
8
.github/workflows/base.yml
vendored
8
.github/workflows/base.yml
vendored
@@ -52,6 +52,12 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: nightly
|
pytorch: nightly
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: next
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -73,7 +79,7 @@ jobs:
|
|||||||
uses: docker/build-push-action@v4
|
uses: docker/build-push-action@v4
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
|
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
38
docker/Dockerfile-base-next
Normal file
38
docker/Dockerfile-base-next
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
ARG CUDA_VERSION="12.8.1"
|
||||||
|
ARG CUDNN_VERSION="8"
|
||||||
|
ARG UBUNTU_VERSION="22.04"
|
||||||
|
ARG MAX_JOBS=4
|
||||||
|
|
||||||
|
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||||
|
|
||||||
|
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||||
|
|
||||||
|
ARG PYTHON_VERSION="3.11"
|
||||||
|
ARG PYTORCH_VERSION="next"
|
||||||
|
ARG CUDA="128"
|
||||||
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
|
||||||
|
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||||
|
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||||
|
|
||||||
|
RUN apt-get update \
|
||||||
|
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||||
|
&& wget \
|
||||||
|
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||||
|
&& mkdir /root/.conda \
|
||||||
|
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||||
|
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||||
|
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||||
|
|
||||||
|
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||||
|
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
|
||||||
|
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||||
|
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||||
|
|
||||||
|
RUN git lfs install --skip-repo && \
|
||||||
|
pip3 install awscli && \
|
||||||
|
pip3 install -U --no-cache-dir pydantic==2.10.6
|
||||||
@@ -510,7 +510,8 @@ train_on_inputs: false
|
|||||||
# Note that training loss may have an oscillating pattern with this enabled.
|
# Note that training loss may have an oscillating pattern with this enabled.
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
|
|
||||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
|
||||||
|
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||||
# gradient_checkpointing_kwargs:
|
# gradient_checkpointing_kwargs:
|
||||||
@@ -686,10 +687,9 @@ ddp_broadcast_buffers:
|
|||||||
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||||
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||||
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
||||||
sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
|
sequence_parallel_degree:
|
||||||
flash_attention: true # SP requires flash attention
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
micro_batch_size: 1 # SP requires this is set to 1
|
# Must evenly divide the number of KV heads in your model.
|
||||||
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
|
|
||||||
# Path to torch distx for optim 'adamw_anyprecision'
|
# Path to torch distx for optim 'adamw_anyprecision'
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ We currently support several common model architectures, including (but not limi
|
|||||||
- `qwen2`
|
- `qwen2`
|
||||||
- `gemma`
|
- `gemma`
|
||||||
- `gemma2`
|
- `gemma2`
|
||||||
|
- `gemma3`
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
|
||||||
|
|||||||
@@ -23,10 +23,9 @@ Use sequence parallelism when:
|
|||||||
To enable sequence parallelism, add the following to your configuration file:
|
To enable sequence parallelism, add the following to your configuration file:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
# Set to a divisor (> 1) of the number of GPUs available
|
||||||
flash_attention: true # SP requires flash attention
|
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||||
micro_batch_size: 1 # SP requires this is set to 1
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -67,16 +66,15 @@ sequence_len: 8192
|
|||||||
...
|
...
|
||||||
|
|
||||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||||
flash_attention: true # SP requires flash attention
|
flash_attention: true # Required with sequence parallelism
|
||||||
micro_batch_size: 1 # SP requires this is set to 1
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
|
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
This will train the Llama 3 8B model with 8192 context length, with each sequence split
|
This will train the Llama 3 8B model with 8K context length, with each sequence split
|
||||||
into 4 subsequences of length 2048 across 4 GPUs.
|
into 2 subsequences of length 4096 across 2 GPUs.
|
||||||
|
|
||||||
## Sample Packing with Sequence Parallelism
|
## Sample Packing with Sequence Parallelism
|
||||||
|
|
||||||
@@ -88,14 +86,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
|
|||||||
|
|
||||||
## Effect on Batch Size
|
## Effect on Batch Size
|
||||||
|
|
||||||
First, note that sequence parallelism supports only the case where `micro_batch_size: 1`.
|
|
||||||
|
|
||||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||||
|
|
||||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||||
- The number of batches processed per step decreases
|
- The number of batches processed per step decreases
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
- With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
|
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||||
- If your per-GPU `micro_batch_size` is 1, the global batch size decreases from 8 to 2
|
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||||
|
|||||||
68
examples/gemma3/gemma-3-4b-qlora.yml
Normal file
68
examples/gemma3/gemma-3-4b-qlora.yml
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
base_model: google/gemma-3-4b-it
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
# gemma3 doesn't seem to play nice with ddp
|
||||||
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
|
chat_template: gemma3
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -2,6 +2,8 @@ base_model: google/gemma-3-4b-it
|
|||||||
processor_type: AutoProcessor
|
processor_type: AutoProcessor
|
||||||
strict: false
|
strict: false
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
skip_prepare_dataset: true
|
skip_prepare_dataset: true
|
||||||
remove_unused_columns: false
|
remove_unused_columns: false
|
||||||
@@ -20,7 +22,7 @@ dataset_prepared_path: last_run_prepared
|
|||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./outputs/out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
adapter: lora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
@@ -82,6 +82,3 @@ deepspeed:
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
fsdp:
|
fsdp:
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
pad_token: "<|end_of_text|>"
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ transformers==4.50.3
|
|||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.5.2
|
accelerate==1.5.2
|
||||||
datasets==3.5.0
|
datasets==3.5.0
|
||||||
deepspeed==0.16.4
|
deepspeed==0.15.4
|
||||||
trl==0.16.0
|
trl==0.16.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -112,7 +112,7 @@ extras_require = {
|
|||||||
"yunchang==0.6.0",
|
"yunchang==0.6.0",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.16.4",
|
"deepspeed==0.15.4",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
],
|
],
|
||||||
"mamba-ssm": [
|
"mamba-ssm": [
|
||||||
|
|||||||
@@ -4,4 +4,4 @@ import pkgutil
|
|||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||||
|
|
||||||
__version__ = "0.8.0.dev0"
|
__version__ = "0.8.0"
|
||||||
|
|||||||
@@ -256,7 +256,7 @@ def do_cli(
|
|||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, inference=True, **kwargs)
|
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
||||||
parsed_cfg.sample_packing = False
|
parsed_cfg.sample_packing = False
|
||||||
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
|
|||||||
@@ -74,8 +74,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
|||||||
load_in_8bit=False,
|
load_in_8bit=False,
|
||||||
load_in_4bit=False,
|
load_in_4bit=False,
|
||||||
flash_attention=False,
|
flash_attention=False,
|
||||||
|
sequence_parallel_degree=None,
|
||||||
deepspeed=None,
|
deepspeed=None,
|
||||||
fsdp=None,
|
fsdp=None,
|
||||||
|
fsdp_config=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -86,13 +88,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
|||||||
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
|
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
|
||||||
)
|
)
|
||||||
|
|
||||||
parsed_cfg.load_in_4bit = False
|
|
||||||
parsed_cfg.load_in_8bit = False
|
|
||||||
parsed_cfg.flash_attention = False
|
|
||||||
parsed_cfg.deepspeed = None
|
|
||||||
parsed_cfg.fsdp = None
|
|
||||||
parsed_cfg.fsdp_config = None
|
|
||||||
|
|
||||||
do_merge_lora(cfg=parsed_cfg)
|
do_merge_lora(cfg=parsed_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1043,10 +1043,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rpo_alpha is not None:
|
if self.cfg.rpo_alpha is not None:
|
||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
|
|
||||||
training_args_kwargs["sequence_parallel_degree"] = (
|
|
||||||
self.cfg.sequence_parallel_degree
|
|
||||||
)
|
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
blocklist_args_kwargs = []
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl == "simpo":
|
if self.cfg.rl == "simpo":
|
||||||
@@ -1165,7 +1161,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["dataset_tags"] = [
|
dpo_trainer_kwargs["dataset_tags"] = [
|
||||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
]
|
]
|
||||||
|
|
||||||
dpo_trainer = trainer_cls(
|
dpo_trainer = trainer_cls(
|
||||||
*trainer_cls_args,
|
*trainer_cls_args,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
@@ -1183,3 +1178,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer.add_callback(callback)
|
dpo_trainer.add_callback(callback)
|
||||||
|
|
||||||
return dpo_trainer
|
return dpo_trainer
|
||||||
|
|
||||||
|
|
||||||
|
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||||
|
"""
|
||||||
|
HF Factory class for PPO Trainer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_callbacks(self):
|
||||||
|
callbacks = super().get_callbacks()
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
|
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
def build(self, total_num_steps):
|
||||||
|
# build PPOConfig
|
||||||
|
pass
|
||||||
|
|||||||
@@ -3,16 +3,16 @@
|
|||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from .base import AxolotlTrainer
|
||||||
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
|
from .dpo.trainer import AxolotlDPOTrainer
|
||||||
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
|
from .grpo.trainer import AxolotlGRPOTrainer
|
||||||
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
|
from .mamba import AxolotlMambaTrainer
|
||||||
from axolotl.core.trainers.relora import ReLoRATrainer
|
from .relora import ReLoRATrainer
|
||||||
from axolotl.core.trainers.trl import (
|
from .trl import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlORPOTrainer,
|
AxolotlORPOTrainer,
|
||||||
AxolotlPPOTrainer,
|
|
||||||
AxolotlPRMTrainer,
|
AxolotlPRMTrainer,
|
||||||
AxolotlRewardTrainer,
|
AxolotlRewardTrainer,
|
||||||
|
TRLPPOTrainer,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,11 +8,10 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Literal
|
from typing import Literal
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from torch.utils.data import (
|
from torch.utils.data import (
|
||||||
BatchSampler,
|
BatchSampler,
|
||||||
@@ -26,8 +25,12 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
|||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
from axolotl.core.trainers.mixins import (
|
||||||
from axolotl.core.trainers.mixins import TrainerMixins
|
OptimizerMixin,
|
||||||
|
RngLoaderMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
SequenceParallelMixin,
|
||||||
|
)
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
@@ -37,7 +40,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(TrainerMixins, Trainer):
|
class AxolotlTrainer(
|
||||||
|
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
|
||||||
|
):
|
||||||
"""Extend the base Trainer for axolotl helpers"""
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
@@ -63,7 +68,9 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
|||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
self.sequence_parallel_handler = SequenceParallelHandler(self.args)
|
# Initialize sequence parallelism if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._setup_sequence_parallel()
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
if self.args.torch_compile:
|
if self.args.torch_compile:
|
||||||
@@ -124,7 +131,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
|||||||
|
|
||||||
# Determine the base sampler first
|
# Determine the base sampler first
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset)
|
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
||||||
elif self.args.curriculum_sampling:
|
elif self.args.curriculum_sampling:
|
||||||
base_sampler = SequentialSampler(self.train_dataset)
|
base_sampler = SequentialSampler(self.train_dataset)
|
||||||
elif use_sample_packing:
|
elif use_sample_packing:
|
||||||
@@ -160,7 +167,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
|||||||
|
|
||||||
# Determine the base sampler
|
# Determine the base sampler
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset)
|
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
||||||
elif use_multipack:
|
elif use_multipack:
|
||||||
base_sampler = SequentialSampler(eval_dataset)
|
base_sampler = SequentialSampler(eval_dataset)
|
||||||
else:
|
else:
|
||||||
@@ -232,10 +239,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
|||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
# Otherwise prepare with accelerator
|
# Otherwise prepare with accelerator
|
||||||
dataloader = self.accelerator.prepare_data_loader(dataloader)
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
"""Get dataloader for training"""
|
"""Get dataloader for training"""
|
||||||
@@ -344,57 +348,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
|||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
return DataLoader(bench_dataset, **dataloader_params)
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
|
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||||
def training_step(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
inputs: dict[str, torch.Tensor | Any],
|
|
||||||
num_items_in_batch: int | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Perform a training step on a batch of inputs. Overrides the
|
|
||||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
|
||||||
enabled.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model to perform training step for.
|
|
||||||
inputs: Dictionary mapping of inputs.
|
|
||||||
num_items_in_batch: The number of items in the batch.
|
|
||||||
"""
|
|
||||||
# Set up sequence parallelism for this step if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
|
||||||
|
|
||||||
# Proceed with normal training step
|
|
||||||
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
|
||||||
|
|
||||||
def prediction_step(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
inputs: dict[str, torch.Tensor | Any],
|
|
||||||
prediction_loss_only: bool,
|
|
||||||
ignore_keys: list[str] | None = None,
|
|
||||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
|
||||||
"""
|
|
||||||
Perform a prediction step on a batch of inputs. Overrides the
|
|
||||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
|
||||||
enabled.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model to perform prediction step for.
|
|
||||||
inputs: Dictionary mapping of inputs.
|
|
||||||
prediction_loss_only: Whether to return only the loss.
|
|
||||||
ignore_keys: Keys to ignore in the inputs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (loss, logits, labels).
|
|
||||||
"""
|
|
||||||
# Set up sequence parallelism for this prediction step if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
|
||||||
|
|
||||||
# Proceed with normal prediction step
|
|
||||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
"""DPO Specific Strategy for training"""
|
"""
|
||||||
|
DPO Specific Strategy for training
|
||||||
|
"""
|
||||||
|
|
||||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||||
|
|
||||||
|
|
||||||
class DPOStrategy:
|
class DPOStrategy:
|
||||||
"""Strategy for DPO training"""
|
"""
|
||||||
|
Strategy for DPO training
|
||||||
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_trainer_class(cls):
|
def get_trainer_class(cls):
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
"""Axolotl specific DPO args"""
|
"""
|
||||||
|
Axolotl specific DPO args
|
||||||
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@@ -9,4 +11,6 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||||
"""DPO config for DPO training"""
|
"""
|
||||||
|
DPO config for DPO training
|
||||||
|
"""
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
"""DPO trainer for axolotl"""
|
"""
|
||||||
|
DPO trainer for axolotl
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
@@ -10,8 +13,7 @@ from transformers import Trainer
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||||
from axolotl.core.trainers.mixins import TrainerMixins
|
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
@@ -21,18 +23,18 @@ if is_sagemaker_mp_enabled():
|
|||||||
import smdistributed.modelparallel.torch as smp
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||||
"""Extend the base DPOTrainer for axolotl helpers"""
|
"""
|
||||||
|
Extend the base DPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
self.model_accepts_loss_kwargs = False
|
self.model_accepts_loss_kwargs = False
|
||||||
self.sequence_parallel_handler = SequenceParallelHandler(args=self.args)
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -86,7 +88,7 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
|||||||
max_prompt_length,
|
max_prompt_length,
|
||||||
max_completion_length,
|
max_completion_length,
|
||||||
add_special_tokens,
|
add_special_tokens,
|
||||||
) -> dict:
|
) -> Dict:
|
||||||
res = DPOTrainer.tokenize_row(
|
res = DPOTrainer.tokenize_row(
|
||||||
features,
|
features,
|
||||||
processing_class,
|
processing_class,
|
||||||
@@ -115,9 +117,10 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
|||||||
def training_step(
|
def training_step(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
inputs: dict[str, torch.Tensor | Any | None],
|
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||||
num_items_in_batch=None,
|
num_items_in_batch=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
self.sequence_parallel_handler.prepare_for_training_step(self, inputs)
|
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
gc.collect()
|
||||||
return super().training_step(model, inputs, num_items_in_batch)
|
torch.cuda.empty_cache()
|
||||||
|
return loss
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
"""Axolotl GRPO trainer"""
|
"""
|
||||||
|
Axolotl GRPO trainer
|
||||||
|
"""
|
||||||
|
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
@@ -6,14 +8,16 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
|
|||||||
from trl import GRPOTrainer
|
from trl import GRPOTrainer
|
||||||
from trl.extras.profiling import profiling_decorator
|
from trl.extras.profiling import profiling_decorator
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import TrainerMixins
|
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||||
|
|
||||||
if is_deepspeed_available():
|
if is_deepspeed_available():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
|
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
"""
|
||||||
|
Extend the base GRPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
"""Init for trainer handlers"""
|
|
||||||
|
|
||||||
from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler
|
|
||||||
@@ -1,123 +0,0 @@
|
|||||||
"""Handler class for sequence parallel trainer logic"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.utils.data import DistributedSampler
|
|
||||||
|
|
||||||
|
|
||||||
class SequenceParallelHandler:
|
|
||||||
"""
|
|
||||||
Handler class that encapsulates sequence parallelism functionality.
|
|
||||||
This replaces the SequenceParallelMixin with a composition-based approach.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, args=None):
|
|
||||||
"""
|
|
||||||
Initialize the sequence parallel handler.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: The arguments object containing sequence parallelism settings.
|
|
||||||
"""
|
|
||||||
self.args = args
|
|
||||||
self.ring_attn_group = None
|
|
||||||
|
|
||||||
# Set up sequence parallelism if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self._setup_sequence_parallel()
|
|
||||||
|
|
||||||
def _setup_sequence_parallel(self):
|
|
||||||
"""Set up sequence parallelism environment."""
|
|
||||||
from ring_flash_attn import update_ring_flash_attn_params
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
|
||||||
|
|
||||||
self.update_ring_flash_attn_params = update_ring_flash_attn_params
|
|
||||||
self.ring_attn_group = get_ring_attn_group()
|
|
||||||
|
|
||||||
def create_sequence_parallel_sampler(
|
|
||||||
self,
|
|
||||||
dataset,
|
|
||||||
shuffle=True,
|
|
||||||
is_eval=False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Helper method to create sampler for sequence parallelism (SP).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: Dataset to sample from.
|
|
||||||
shuffle: Whether to shuffle the dataset.
|
|
||||||
is_eval: Whether we are creating a sampler for evaluation or training.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Distributed sampler.
|
|
||||||
"""
|
|
||||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
|
||||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
|
||||||
|
|
||||||
return DistributedSampler(
|
|
||||||
dataset,
|
|
||||||
num_replicas=num_sp_groups,
|
|
||||||
rank=sp_group_id,
|
|
||||||
seed=self.args.seed if shuffle else None,
|
|
||||||
shuffle=shuffle,
|
|
||||||
drop_last=not is_eval,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_train_sampler(self, dataset):
|
|
||||||
"""
|
|
||||||
Get a training sampler configured for sequence parallelism.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: The training dataset.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured sequence parallel sampler.
|
|
||||||
"""
|
|
||||||
return self.create_sequence_parallel_sampler(
|
|
||||||
dataset,
|
|
||||||
shuffle=not self.args.curriculum_sampling,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_eval_sampler(self, eval_dataset):
|
|
||||||
"""
|
|
||||||
Get an evaluation sampler configured for sequence parallelism.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
eval_dataset: The evaluation dataset.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured sequence parallel sampler.
|
|
||||||
"""
|
|
||||||
return self.create_sequence_parallel_sampler(
|
|
||||||
eval_dataset, shuffle=False, is_eval=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_ring_flash_attn_params(self, inputs):
|
|
||||||
"""
|
|
||||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
|
||||||
the substituted ring_flash_attn.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: Current batch of inputs.
|
|
||||||
"""
|
|
||||||
# At this point, inputs should already be partitioned by the sequence
|
|
||||||
# parallel data collator
|
|
||||||
batch_size = inputs["input_ids"].shape[0]
|
|
||||||
seq_len = inputs["input_ids"].shape[1]
|
|
||||||
packed_seq_lens = [seq_len] * batch_size
|
|
||||||
|
|
||||||
# Calculate the full sequence length across all GPUs in this SP group
|
|
||||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
|
||||||
|
|
||||||
cu_seqlens = torch.cumsum(
|
|
||||||
torch.tensor(
|
|
||||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
cu_seqlens = F.pad(
|
|
||||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
|
||||||
)
|
|
||||||
|
|
||||||
self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
|
||||||
@@ -3,12 +3,7 @@
|
|||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
|
from .optimizer import OptimizerMixin
|
||||||
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
|
from .rng_state_loader import RngLoaderMixin
|
||||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
from .scheduler import SchedulerMixin
|
||||||
|
from .sequence_parallel import SequenceParallelMixin
|
||||||
|
|
||||||
class TrainerMixins(
|
|
||||||
OptimizerMixin, RngLoaderMixin, SchedulerMixin
|
|
||||||
):
|
|
||||||
"""Stub class combining all mixins for Axolotl trainers."""
|
|
||||||
|
|||||||
@@ -21,7 +21,9 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class RngLoaderMixin(Trainer):
|
class RngLoaderMixin(Trainer):
|
||||||
"""Mixin for method override to load RNG states from a checkpoint"""
|
"""
|
||||||
|
mixin for method override to load RNG states from a checkpoint
|
||||||
|
"""
|
||||||
|
|
||||||
def _load_rng_state(self, checkpoint):
|
def _load_rng_state(self, checkpoint):
|
||||||
# Load RNG states from `checkpoint`
|
# Load RNG states from `checkpoint`
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||||
# TODO(Dan): remove
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -8,6 +7,7 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import DistributedSampler, Sampler
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
@@ -71,12 +71,12 @@ class SequenceParallelMixin:
|
|||||||
drop_last=not is_eval,
|
drop_last=not is_eval,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_train_sampler(self, dataset) -> Sampler | None:
|
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
||||||
"""
|
"""
|
||||||
Get a training sampler configured for sequence parallelism.
|
Get a training sampler configured for sequence parallelism.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: The training dataset.
|
dataset: The training dataset
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Configured sequence parallel sampler.
|
Configured sequence parallel sampler.
|
||||||
@@ -86,7 +86,7 @@ class SequenceParallelMixin:
|
|||||||
shuffle=not self.args.curriculum_sampling,
|
shuffle=not self.args.curriculum_sampling,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||||
"""
|
"""
|
||||||
Get an evaluation sampler configured for sequence parallelism.
|
Get an evaluation sampler configured for sequence parallelism.
|
||||||
|
|
||||||
@@ -130,3 +130,53 @@ class SequenceParallelMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||||
|
|
||||||
|
def training_step(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
|
num_items_in_batch: int | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Perform a training step on a batch of inputs. Overrides the
|
||||||
|
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||||
|
enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model to perform training step for.
|
||||||
|
inputs: Dictionary mapping.
|
||||||
|
"""
|
||||||
|
# Set up sequence parallelism for this step if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._update_ring_flash_attn_params(inputs)
|
||||||
|
|
||||||
|
# Proceed with normal training step
|
||||||
|
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
||||||
|
|
||||||
|
def prediction_step(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
|
prediction_loss_only: bool,
|
||||||
|
ignore_keys: list[str] | None = None,
|
||||||
|
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||||||
|
"""
|
||||||
|
Perform a prediction step on a batch of inputs. Overrides the
|
||||||
|
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||||
|
enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model to perform prediction step for.
|
||||||
|
inputs: Dictionary mapping of inputs.
|
||||||
|
prediction_loss_only: Whether to return only the loss.
|
||||||
|
ignore_keys: Keys to ignore in the inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (loss, logits, labels).
|
||||||
|
"""
|
||||||
|
# Set up sequence parallelism for this prediction step if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._update_ring_flash_attn_params(inputs)
|
||||||
|
|
||||||
|
# Proceed with normal prediction step
|
||||||
|
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
||||||
|
|||||||
@@ -13,10 +13,11 @@ from trl import (
|
|||||||
RewardTrainer,
|
RewardTrainer,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import TrainerMixins
|
from axolotl.core.trainers.mixins import RngLoaderMixin
|
||||||
|
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
|
class TRLPPOTrainer(PPOTrainer):
|
||||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "ppo"]
|
tag_names = ["axolotl", "ppo"]
|
||||||
@@ -74,8 +75,10 @@ class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
|
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
||||||
"""Extend the base ORPOTrainer for axolotl helpers"""
|
"""
|
||||||
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
@@ -152,14 +155,18 @@ class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
|
|||||||
return loss, metrics
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
|
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
|
||||||
"""Extend the base KTOTrainer for axolotl helpers"""
|
"""
|
||||||
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
|
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
||||||
"""Extend the base CPOTrainer for axolotl helpers"""
|
"""
|
||||||
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
@@ -238,13 +245,17 @@ class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
|
|||||||
return loss, metrics
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
|
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
|
||||||
"""Extend the base RewardTrainer for axolotl helpers"""
|
"""
|
||||||
|
Extend the base RewardTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
|
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
|
||||||
"""Extend the base trl.PRMTrainer for axolotl helpers"""
|
"""
|
||||||
|
Extend the base trl.PRMTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "prm"]
|
tag_names = ["axolotl", "prm"]
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingMixins:
|
class AxolotlTrainingMixins:
|
||||||
"""Mixin class for the Axolotl training args."""
|
"""
|
||||||
|
Mixin class for the Axolotl training args.
|
||||||
|
"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
|
|||||||
@@ -6,22 +6,11 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
|||||||
their sequence parallel version of Flash Attention 2.
|
their sequence parallel version of Flash Attention 2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
try:
|
|
||||||
from ring_flash_attn import update_ring_flash_attn_params
|
|
||||||
except ImportError:
|
|
||||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
|
||||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -43,120 +32,12 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
|||||||
Setter for ring attention group on this rank.
|
Setter for ring attention group on this rank.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ring_attn_group: Process group for ring attention.
|
Process group for ring attention.
|
||||||
"""
|
"""
|
||||||
global RING_ATTN_GROUP # pylint: disable=global-statement
|
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||||
RING_ATTN_GROUP = ring_attn_group
|
RING_ATTN_GROUP = ring_attn_group
|
||||||
|
|
||||||
|
|
||||||
def patch_flash_attention_for_sequential_batch(sequence_parallel_degree: int):
|
|
||||||
"""
|
|
||||||
Patch flash attention a second time to handle batched data. This is a hack to
|
|
||||||
accommodate certain RL trainers which batch data even when `micro_batch_size: 1` is
|
|
||||||
specified in the Axolotl config.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sequence_parallel_degree: Sequence parallelism factor.
|
|
||||||
"""
|
|
||||||
# Store the original flash attention function
|
|
||||||
original_flash_attention = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
|
|
||||||
|
|
||||||
def sequential_batch_flash_attention(
|
|
||||||
module: torch.nn.Module,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
attention_mask: torch.Tensor | None,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
scaling: float | None = None,
|
|
||||||
sliding_window: int | None = None,
|
|
||||||
softcap: float | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> tuple[torch.Tensor, None]:
|
|
||||||
# Check if we have a batch dimension > 1
|
|
||||||
batch_size = query.shape[0]
|
|
||||||
|
|
||||||
if batch_size <= 1:
|
|
||||||
return original_flash_attention(
|
|
||||||
module,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
attention_mask,
|
|
||||||
dropout,
|
|
||||||
scaling,
|
|
||||||
sliding_window,
|
|
||||||
softcap,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process each item in the batch separately
|
|
||||||
outputs = []
|
|
||||||
|
|
||||||
for i in range(batch_size):
|
|
||||||
# Extract single batch item
|
|
||||||
q_item = query[i:i+1]
|
|
||||||
k_item = key[i:i+1]
|
|
||||||
v_item = value[i:i+1]
|
|
||||||
|
|
||||||
# Handle attention mask - it might be None or have different shapes
|
|
||||||
mask_item = None
|
|
||||||
if attention_mask is not None:
|
|
||||||
# The mask could have different formats depending on implementation
|
|
||||||
if attention_mask.dim() >= 3 and attention_mask.shape[0] == batch_size:
|
|
||||||
mask_item = attention_mask[i:i+1]
|
|
||||||
else:
|
|
||||||
# For broadcast masks that don't have a batch dimension
|
|
||||||
mask_item = attention_mask
|
|
||||||
|
|
||||||
# At this point, inputs should already be partitioned by the sequence
|
|
||||||
# parallel data collator
|
|
||||||
batch_size = q_item.shape[0]
|
|
||||||
seq_len = q_item.shape[2]
|
|
||||||
packed_seq_lens = [seq_len] * batch_size
|
|
||||||
|
|
||||||
# Calculate the full sequence length across all GPUs in this SP group
|
|
||||||
total_seq_len = seq_len * sequence_parallel_degree
|
|
||||||
|
|
||||||
cu_seqlens = torch.cumsum(
|
|
||||||
torch.tensor(
|
|
||||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
cu_seqlens = F.pad(
|
|
||||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
|
||||||
)
|
|
||||||
|
|
||||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
|
||||||
|
|
||||||
# Call the original function for a single batch item
|
|
||||||
output, _ = original_flash_attention(
|
|
||||||
module,
|
|
||||||
q_item,
|
|
||||||
k_item,
|
|
||||||
v_item,
|
|
||||||
mask_item,
|
|
||||||
dropout,
|
|
||||||
scaling,
|
|
||||||
sliding_window,
|
|
||||||
softcap,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs.append(output)
|
|
||||||
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
# Concatenate results along batch dimension
|
|
||||||
concatenated_output = torch.cat(outputs, dim=0)
|
|
||||||
return concatenated_output, None
|
|
||||||
|
|
||||||
# Replace the original function with our sequential version
|
|
||||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = sequential_batch_flash_attention
|
|
||||||
|
|
||||||
|
|
||||||
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
||||||
"""
|
"""
|
||||||
Create ring attention group and substitute flash attn with ring flash attn.
|
Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
@@ -217,4 +98,3 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
|||||||
substitute_hf_flash_attn(
|
substitute_hf_flash_attn(
|
||||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||||
)
|
)
|
||||||
patch_flash_attention_for_sequential_batch(sequence_parallel_degree)
|
|
||||||
|
|||||||
238
src/axolotl/monkeypatch/gemma3.py
Normal file
238
src/axolotl/monkeypatch/gemma3.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""Monkeypatch for gemma3 conditional generation forward to fix loss exploding"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.models.gemma3.modeling_gemma3 import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
GEMMA3_INPUTS_DOCSTRING,
|
||||||
|
Gemma3CausalLMOutputWithPast,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_torchdynamo_compiling,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
def new_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**lm_kwargs,
|
||||||
|
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
|
||||||
|
|
||||||
|
>>> prompt = "answer en Where is the cow standing?"
|
||||||
|
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(**inputs, max_length=30)
|
||||||
|
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"answer en Where is the cow standing?\nbeach"
|
||||||
|
```"""
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
is_training = token_type_ids is not None and labels is not None
|
||||||
|
|
||||||
|
# Replace image id with PAD if the image token is OOV, to avoid index-errors
|
||||||
|
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||||
|
special_image_mask = input_ids == self.config.image_token_index
|
||||||
|
llm_input_ids = input_ids.clone()
|
||||||
|
llm_input_ids[special_image_mask] = 0
|
||||||
|
else:
|
||||||
|
llm_input_ids = input_ids
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = (
|
||||||
|
past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
)
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens,
|
||||||
|
past_seen_tokens + inputs_embeds.shape[1],
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge text and images
|
||||||
|
if pixel_values is not None:
|
||||||
|
image_features = self.get_image_features(pixel_values)
|
||||||
|
|
||||||
|
if input_ids is None:
|
||||||
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
|
torch.tensor(
|
||||||
|
self.config.image_token_index,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
|
||||||
|
-1
|
||||||
|
)
|
||||||
|
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||||
|
inputs_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not is_torchdynamo_compiling()
|
||||||
|
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||||
|
):
|
||||||
|
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of images does not match number of special image tokens in the input text. "
|
||||||
|
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||||
|
"tokens from image embeddings."
|
||||||
|
)
|
||||||
|
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||||
|
|
||||||
|
# mask out pad-token-ids in labels for BC
|
||||||
|
if labels is not None and self.pad_token_id in labels:
|
||||||
|
logger.warning_once(
|
||||||
|
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
||||||
|
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
||||||
|
)
|
||||||
|
labels = torch.where(
|
||||||
|
input_ids == self.pad_token_id, self.config.ignore_index, labels
|
||||||
|
)
|
||||||
|
|
||||||
|
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
past_key_values,
|
||||||
|
cache_position,
|
||||||
|
inputs_embeds,
|
||||||
|
is_training,
|
||||||
|
)
|
||||||
|
outputs = self.language_model(
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = outputs[0]
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
if attention_mask is not None:
|
||||||
|
# Get the shifted attention mask
|
||||||
|
shift_attention_mask = attention_mask[:, -logits.shape[1] + 1 :].to(
|
||||||
|
logits.device
|
||||||
|
) # +1 for shift
|
||||||
|
|
||||||
|
# Filter logits and labels based on attention mask
|
||||||
|
valid_indices = shift_attention_mask != 0
|
||||||
|
filtered_logits = logits[..., :-1, :][valid_indices]
|
||||||
|
filtered_labels = labels[..., 1:][valid_indices.to(labels.device)]
|
||||||
|
|
||||||
|
# TODO: do we need to handle num_items_in_batch given we filter the logits and labels?
|
||||||
|
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=filtered_logits,
|
||||||
|
labels=None, # we pass shift_labels
|
||||||
|
shift_labels=filtered_labels,
|
||||||
|
vocab_size=self.config.text_config.vocab_size,
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Standard case without filtering
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.text_config.vocab_size,
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Gemma3CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
image_hidden_states=image_features if pixel_values is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma3conditionalgeneration_forward():
|
||||||
|
from transformers.models.gemma3.modeling_gemma3 import (
|
||||||
|
Gemma3ForConditionalGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
Gemma3ForConditionalGeneration.forward = new_forward
|
||||||
@@ -252,12 +252,38 @@ def apply_lora_kernel_patches(
|
|||||||
LOG.setLevel(logging.INFO)
|
LOG.setLevel(logging.INFO)
|
||||||
|
|
||||||
# Choose activation based on model type
|
# Choose activation based on model type
|
||||||
activation = model.config.hidden_act
|
activation = None
|
||||||
|
text_config = (
|
||||||
|
model.config.get_text_config()
|
||||||
|
if hasattr(model.config, "get_text_config")
|
||||||
|
else model.config
|
||||||
|
)
|
||||||
|
if hasattr(text_config, "hidden_act"):
|
||||||
|
activation = text_config.hidden_act
|
||||||
|
elif hasattr(text_config, "hidden_activation"):
|
||||||
|
activation = text_config.hidden_activation
|
||||||
|
|
||||||
|
# map activation to supported activation
|
||||||
|
if "gelu" in activation:
|
||||||
|
# gemma3 uses gelu_pytorch_tanh
|
||||||
|
activation = "gelu"
|
||||||
|
|
||||||
if activation not in SUPPORTED_ACTIVATIONS:
|
if activation not in SUPPORTED_ACTIVATIONS:
|
||||||
raise NotImplementedError(f"Activation {activation} is not supported")
|
raise NotImplementedError(f"Activation {activation} is not supported")
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
# check for multimodal models first
|
||||||
|
if hasattr(model, "language_model"):
|
||||||
|
layers = model.language_model.model.layers
|
||||||
|
elif hasattr(model, "model"):
|
||||||
|
layers = model.model.model.layers
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
|
||||||
|
)
|
||||||
|
|
||||||
# Patch each layer
|
# Patch each layer
|
||||||
for layer in model.model.model.layers:
|
for layer in layers:
|
||||||
# Add QKV, O fallback implementations to start
|
# Add QKV, O fallback implementations to start
|
||||||
# These will be overwritten later (if some conditions apply)
|
# These will be overwritten later (if some conditions apply)
|
||||||
layer.self_attn.apply_qkv = types.MethodType(
|
layer.self_attn.apply_qkv = types.MethodType(
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ def resolve_dtype(cfg):
|
|||||||
cfg.bf16 = False
|
cfg.bf16 = False
|
||||||
else:
|
else:
|
||||||
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
||||||
|
torch.backends.cudnn.allow_tf32 = cfg.tf32 or False
|
||||||
if cfg.bf16:
|
if cfg.bf16:
|
||||||
cfg.fp16 = False
|
cfg.fp16 = False
|
||||||
|
|
||||||
|
|||||||
@@ -535,6 +535,15 @@ class ModelLoader:
|
|||||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
|
||||||
def apply_patches(self) -> None:
|
def apply_patches(self) -> None:
|
||||||
|
# patch gemma3 conditional generation forward before loading plugins
|
||||||
|
# as it could be overridden by plugins
|
||||||
|
if self.cfg.model_config_type == "gemma3":
|
||||||
|
from axolotl.monkeypatch.gemma3 import (
|
||||||
|
patch_gemma3conditionalgeneration_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_gemma3conditionalgeneration_forward()
|
||||||
|
|
||||||
# load any patches from plugins
|
# load any patches from plugins
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
|
|
||||||
@@ -1351,7 +1360,9 @@ def load_model(
|
|||||||
reference_model: bool = False,
|
reference_model: bool = False,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||||
"""Load a model for a given configuration and tokenizer."""
|
"""
|
||||||
|
Load a model for a given configuration and tokenizer.
|
||||||
|
"""
|
||||||
model_loader = ModelLoader(
|
model_loader = ModelLoader(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -1360,16 +1371,12 @@ def load_model(
|
|||||||
reference_model=reference_model,
|
reference_model=reference_model,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_loader.load_model()
|
return model_loader.load_model()
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(
|
def load_adapter(model, cfg, adapter, inference=False):
|
||||||
model: PreTrainedModel,
|
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
cfg: DictDefault,
|
|
||||||
adapter: str | None,
|
|
||||||
inference: bool = False,
|
|
||||||
) -> tuple[PreTrainedModel, PeftConfig | None]:
|
|
||||||
if adapter is None:
|
if adapter is None:
|
||||||
return model, None
|
return model, None
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
@@ -1382,9 +1389,8 @@ def load_adapter(
|
|||||||
raise NotImplementedError(f"{adapter} peft adapter not available")
|
raise NotImplementedError(f"{adapter} peft adapter not available")
|
||||||
|
|
||||||
|
|
||||||
def load_llama_adapter(
|
def load_llama_adapter(model, cfg):
|
||||||
model: PreTrainedModel, cfg: DictDefault
|
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
) -> tuple[PreTrainedModel, PeftConfig | None]:
|
|
||||||
from peft import AdaptionPromptConfig, get_peft_model
|
from peft import AdaptionPromptConfig, get_peft_model
|
||||||
|
|
||||||
peft_config = AdaptionPromptConfig(
|
peft_config = AdaptionPromptConfig(
|
||||||
@@ -1408,7 +1414,7 @@ def load_llama_adapter(
|
|||||||
return model, peft_config
|
return model, peft_config
|
||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(model: PreTrainedModel):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
|
|||||||
@@ -1224,17 +1224,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
):
|
):
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_fsdp = data.get("fsdp") is not None
|
is_fsdp = data.get("fsdp") is not None
|
||||||
is_deepspeed = data.get("deepspeed") is not None
|
|
||||||
|
|
||||||
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
||||||
if is_fsdp:
|
if is_fsdp:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
|
||||||
)
|
)
|
||||||
if is_deepspeed:
|
|
||||||
raise ValueError(
|
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed."
|
|
||||||
)
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1290,3 +1285,5 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
|
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|||||||
Reference in New Issue
Block a user