Compare commits
4 Commits
v0.6.0
...
feat/pref_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8428b3f2c7 | ||
|
|
02629c7cdf | ||
|
|
78a4aa86d6 | ||
|
|
d009ead101 |
2
.github/workflows/pypi.yml
vendored
2
.github/workflows/pypi.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install wheel packaging
|
||||
pip3 install -e .
|
||||
pip3 install --no-build-isolation -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Extract tag name
|
||||
|
||||
11
.github/workflows/tests-nightly.yml
vendored
11
.github/workflows/tests-nightly.yml
vendored
@@ -44,6 +44,11 @@ jobs:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||
@@ -60,11 +65,15 @@ jobs:
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging
|
||||
pip3 install -U -e .
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
16
.github/workflows/tests.yml
vendored
16
.github/workflows/tests.yml
vendored
@@ -78,11 +78,15 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install -U -e .
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
@@ -120,7 +124,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
@@ -129,12 +133,16 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
python3 setup.py sdist
|
||||
pip3 install dist/axolotl*.tar.gz
|
||||
python -m build --no-isolation --sdist
|
||||
pip3 install --no-build-isolation dist/axolotl*.tar.gz
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
include src/setuptools_axolotl_dynamic_dependencies.py
|
||||
recursive-include axolotl *.py
|
||||
|
||||
@@ -112,7 +112,7 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
||||
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
|
||||
|
||||
```bash
|
||||
pip3 install axolotl[flash-attn,deepspeed]
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
|
||||
# download examples and optionally deepspeed configs to the local path
|
||||
axolotl fetch examples
|
||||
@@ -131,7 +131,7 @@ from source.
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
pip3 install packaging ninja
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
### Axolotl CLI Usage
|
||||
@@ -320,7 +320,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
||||
3. Install Axolotl along with python dependencies
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
4. (Optional) Login to Huggingface to use gated models/datasets.
|
||||
```bash
|
||||
@@ -399,7 +399,7 @@ Please use WSL or Docker!
|
||||
|
||||
Use the below instead of the install method in QuickStart.
|
||||
```
|
||||
pip3 install -e '.'
|
||||
pip3 install --no-build-isolation -e '.'
|
||||
```
|
||||
More info: [mac.qmd](/docs/mac.qmd)
|
||||
|
||||
|
||||
@@ -31,9 +31,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
fi
|
||||
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
|
||||
|
||||
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
|
||||
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
# So we can test the Docker image
|
||||
|
||||
@@ -52,7 +52,7 @@ export GPU_ARCHS="gfx90a"
|
||||
cd flash-attention
|
||||
export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
|
||||
patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
|
||||
pip install .
|
||||
pip install --no-build-isolation .
|
||||
```
|
||||
|
||||
### 6. Install Axolotl
|
||||
@@ -63,7 +63,7 @@ Clone and install Axolotl:
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl
|
||||
cd axolotl
|
||||
pip install packaging ninja
|
||||
pip install -e .
|
||||
pip install --no-build-isolation -e .
|
||||
```
|
||||
|
||||
### 7. Apply xformers Workaround
|
||||
|
||||
@@ -71,7 +71,7 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
|
||||
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
#### Remote Hosts
|
||||
@@ -212,7 +212,7 @@ You will now be in the container. Next, perform an editable install of Axolotl:
|
||||
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
### Attach To Container
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install axolotl[deepspeed]"
|
||||
"!pip install --no-build-isolation axolotl[deepspeed]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -17,3 +17,10 @@ Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
|
||||
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
||||
|
||||
[tool.setuptools_scm]
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.cmdclass]
|
||||
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
|
||||
|
||||
@@ -13,5 +13,5 @@ cd /workspace
|
||||
rm -rf /workspace/axolotl
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
pip install --no-deps -e .
|
||||
pip install --no-build-isolation --no-deps -e .
|
||||
```
|
||||
|
||||
@@ -14,17 +14,22 @@ import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from liger_kernel.chunked_loss.fused_linear_preference import (
|
||||
LigerFusedLinearPreferenceBase,
|
||||
)
|
||||
from packaging import version
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from torch import amp, nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import (
|
||||
@@ -1077,6 +1082,15 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
|
||||
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
|
||||
|
||||
self.liger_loss = LigerFusedLinearDPOLoss(
|
||||
ignore_index=self.label_pad_token_id,
|
||||
beta=self.beta,
|
||||
compute_nll_loss=True, # not same as rpo_alpha hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None,
|
||||
use_ref_model=not self.reference_free,
|
||||
)
|
||||
|
||||
def create_optimizer(self):
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
return super().create_optimizer()
|
||||
@@ -1180,6 +1194,309 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
# transformers<=4.46
|
||||
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and logging metrics using the Liger fused kernel.

    The policy hidden states and labels come from ``self.concatenated_forward``
    and are passed, together with the ``lm_head`` weight/bias, to
    ``self.liger_loss``. Unless ``self.reference_free`` is set, a reference
    forward pass is run without gradients to obtain the reference hidden
    states and the reference log-probabilities used for the reward metrics.

    Args:
        model: the policy model; it must expose ``lm_head``.
        batch: the preference batch handed to ``concatenated_forward``.
        train_eval: ``"eval"`` prefixes every metric key with ``eval_``.

    Returns:
        A ``(loss, metrics)`` tuple. ``loss`` is returned exactly as produced
        by ``self.liger_loss``; every tensor stored in ``metrics`` is detached
        and moved to CPU so logging never keeps the autograd graph alive.

    Raises:
        ValueError: if ``self.liger_loss`` has not been initialised.
    """
    # Explicit None check: the loss is a module object, so truthiness is not
    # a reliable "is it configured" test.
    if getattr(self, "liger_loss", None) is None:
        raise ValueError("Liger loss not initialized")

    metrics = {}

    model_output = self.concatenated_forward(model, batch)
    hidden_states = model_output["hidden_states"]
    labels = model_output["labels"]

    # Get the lm_head weights and bias
    lin_weight = model.lm_head.weight
    lin_bias = getattr(model.lm_head, "bias", None)

    ref_hidden_states = None
    ref_weight = None
    ref_bias = None
    ref_chosen_logps = None
    ref_rejected_logps = None

    if not self.reference_free:
        # Adapted from DPO's compute_ref_log_probs
        autocast_context = (
            amp.autocast("cuda")
            if self._peft_has_been_casted_to_bf16
            else nullcontext()
        )
        if self.ref_model is None:
            # No separate reference model: reuse the policy inside
            # null_ref_context().
            ref_model = self.model
            adapter_context = self.null_ref_context()
        else:
            ref_model = self.ref_model
            adapter_context = nullcontext()

        # NOTE(review): indentation was not recoverable from the reviewed
        # source; the whole reference computation is kept under no_grad here.
        with torch.no_grad(), autocast_context:  # type: ignore
            with adapter_context:
                ref_model_output = self.concatenated_forward(ref_model, batch)
                ref_weight = ref_model.lm_head.weight
                ref_bias = getattr(ref_model.lm_head, "bias", None)
            ref_hidden_states = ref_model_output["hidden_states"]

            (
                ref_chosen_logps,
                ref_rejected_logps,
                _ref_chosen_logits,
                _ref_rejected_logits,
                _ref_chosen_nll_loss,
            ) = LigerFusedLinearPreferenceBase.chunk_forward(
                input_chunk=ref_hidden_states,
                weight=ref_weight,
                target_chunk=labels,
                bias=ref_bias,
                # Same ignore index the Liger loss itself is configured with.
                ignore_index=self.label_pad_token_id,
                compute_nll_loss=False,
            )

    # Compute loss using Liger kernel
    loss, return_vars = self.liger_loss(
        lin_weight=lin_weight,
        _input=hidden_states,
        target=labels,
        bias=lin_bias,  # TODO: check whether to pass bias as FCLE doesn't
        ref_input=ref_hidden_states,
        ref_weight=ref_weight,
        ref_bias=ref_bias,
    )

    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits_mean,
        policy_rejected_logits_mean,
        policy_nll_loss,
    ) = return_vars

    # Calculate rewards; they are only logged, so always detach them.
    if self.reference_free:
        chosen_rewards = (self.beta * policy_chosen_logps).detach()
        rejected_rewards = (self.beta * policy_rejected_logps).detach()
    else:
        chosen_rewards = (
            self.beta * (policy_chosen_logps - ref_chosen_logps).detach()
        )
        rejected_rewards = (
            self.beta * (policy_rejected_logps - ref_rejected_logps).detach()
        )

    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    prefix = "eval_" if train_eval == "eval" else ""
    metrics.update(
        {
            f"{prefix}rewards/chosen": chosen_rewards.mean().cpu(),
            f"{prefix}rewards/rejected": rejected_rewards.mean().cpu(),
            f"{prefix}rewards/accuracies": reward_accuracies.mean().cpu(),
            f"{prefix}rewards/margins": (chosen_rewards - rejected_rewards)
            .mean()
            .cpu(),
            f"{prefix}logps/chosen": policy_chosen_logps.detach().mean().cpu(),
            f"{prefix}logps/rejected": policy_rejected_logps.detach()
            .mean()
            .cpu(),
            f"{prefix}logits/chosen": policy_chosen_logits_mean.detach().cpu(),
            f"{prefix}logits/rejected": policy_rejected_logits_mean.detach().cpu(),
        }
    )

    if getattr(self.args, "rpo_alpha", None) is not None:
        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().cpu()

    # TODO: Handle use_weighting, aux_loss_enabled as in upstream

    return loss, metrics
|
||||
|
||||
def concatenated_forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
):
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.

    Overridden base function to return the hidden states and labels for the loss calculation.

    Returns:
        A dict with ``chosen_logps``, ``rejected_logps``, ``mean_chosen_logits``,
        ``mean_rejected_logits``, ``hidden_states`` (last layer, aligned with
        the labels) and ``labels``. Positions excluded from the loss are set
        to ``self.label_pad_token_id`` in the returned ``labels`` so that a
        loss configured with that ignore index skips them. ``policy_weights``,
        ``nll_loss`` and ``aux_loss`` are added when the corresponding options
        are enabled.
    """
    num_examples = batch["prompt_input_ids"].shape[0]  # type: ignore

    concatenated_batch = self.concatenated_inputs(
        batch, padding_value=self.padding_value
    )

    model_kwargs = {}
    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    # Add to get the hidden states for the loss
    model_kwargs["output_hidden_states"] = True

    # Add the pixel values and attention masks for vision models
    for vision_key in ("pixel_values", "pixel_attention_mask", "image_sizes"):
        if vision_key in concatenated_batch:
            model_kwargs[vision_key] = concatenated_batch[vision_key]

    prompt_input_ids = concatenated_batch["prompt_input_ids"]
    prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
    completion_input_ids = concatenated_batch["completion_input_ids"]
    completion_attention_mask = concatenated_batch["completion_attention_mask"]
    if self.is_encoder_decoder:
        labels = completion_input_ids
        labels[completion_attention_mask == 0] = self.label_pad_token_id
        outputs = model(
            input_ids=prompt_input_ids,
            attention_mask=prompt_attention_mask,
            labels=labels,  # we need the labels for the logits to be returned
            **model_kwargs,
        )
        logits = outputs.logits
        hidden_states = outputs.decoder_hidden_states[-1]
        loss_mask = completion_attention_mask.bool()
    else:
        # Concatenate the prompt and completion inputs
        input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
        attention_mask = torch.cat(
            (prompt_attention_mask, completion_attention_mask), dim=1
        )
        # Mask the prompt but not the completion for the loss
        loss_mask = torch.cat(
            (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
            dim=1,
        )

        # Flush left to reduce the memory usage
        # [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
        #  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
        for i in range(attention_mask.size(0)):
            first_one_idx = torch.nonzero(attention_mask[i])[0].item()
            input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)  # type: ignore
            attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)  # type: ignore
            loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx)  # type: ignore

        # Get the first column idx that is all zeros and remove every column after that
        empty_cols = torch.sum(attention_mask, dim=0) == 0
        first_empty_col = (
            torch.nonzero(empty_cols)[0].item()
            if empty_cols.any()
            else attention_mask.size(1)
        )
        input_ids = input_ids[:, :first_empty_col]  # type: ignore
        attention_mask = attention_mask[:, :first_empty_col]  # type: ignore
        loss_mask = loss_mask[:, :first_empty_col]  # type: ignore

        # Truncate right
        if self.args.max_length is not None:
            input_ids = input_ids[:, : self.args.max_length]
            attention_mask = attention_mask[:, : self.args.max_length]
            loss_mask = loss_mask[:, : self.args.max_length]

        # NOTE: the `use_num_logits_to_keep` trimming is left disabled in this
        # override, so logits and hidden states cover every position.

        outputs = model(
            input_ids=input_ids, attention_mask=attention_mask, **model_kwargs
        )

        # Offset the logits by one to align with the labels
        logits = outputs.logits[:, :-1, :]
        hidden_states = outputs.hidden_states[-1][:, :-1, :]
        labels = input_ids[:, 1:].clone()
        loss_mask = loss_mask[:, 1:].bool()

    if logits.shape[:2] != labels.shape[:2]:
        # for llava, the returned logits include the image tokens (placed before the text tokens)
        seq_len = labels.shape[1]
        logits = logits[:, -seq_len:]
        hidden_states = hidden_states[:, -seq_len:]

    # Compute the log probabilities of the labels
    labels[
        ~loss_mask
    ] = 0  # dummy token; we'll ignore the losses on these tokens later
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    per_token_logps[~loss_mask] = 0
    all_logps = per_token_logps.sum(-1)

    output = {}

    if self.use_weighting:
        with torch.no_grad():
            # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827
            logprobs = F.log_softmax(logits, dim=-1)
            weights_adjustment_factor = torch.logsumexp(
                2 * logprobs, dim=-1
            )  # same as sum(probs**2) in log space
            per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
            all_weights = (per_token_logps_adjusted * loss_mask).sum(
                -1
            ) / loss_mask.sum(-1)
            chosen_weights = all_weights[:num_examples]
            rejected_weights = all_weights[num_examples:]
            output["policy_weights"] = torch.clamp(
                torch.exp(chosen_weights + rejected_weights), max=1
            )

    if self.args.rpo_alpha is not None:
        # Only use the chosen logits for the RPO loss
        chosen_logits = logits[:num_examples]
        chosen_labels = labels[:num_examples]

        # Compute the log probabilities of the labels
        output["nll_loss"] = F.cross_entropy(
            torch.flatten(chosen_logits, end_dim=1),
            torch.flatten(chosen_labels, end_dim=1),
            ignore_index=0,
        )

    if self.loss_type == "ipo":
        all_logps = all_logps / loss_mask.sum(-1)

    output["chosen_logps"] = all_logps[:num_examples]
    output["rejected_logps"] = all_logps[num_examples:]
    output["mean_chosen_logits"] = logits[:num_examples][
        loss_mask[:num_examples]
    ].mean()
    output["mean_rejected_logits"] = logits[num_examples:][
        loss_mask[num_examples:]
    ].mean()
    output["hidden_states"] = hidden_states
    # The local `labels` use the dummy token 0 at masked positions (required
    # for the gather above). Hand out a copy that marks those positions with
    # label_pad_token_id instead, so they can be told apart from real token 0.
    output["labels"] = labels.masked_fill(~loss_mask, self.label_pad_token_id)

    if self.aux_loss_enabled:
        output["aux_loss"] = outputs.aux_loss

    return output
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
@@ -2163,6 +2480,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
|
||||
report_to = []
|
||||
if self.cfg.use_wandb:
|
||||
report_to.append("wandb")
|
||||
if self.cfg.wandb_name:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
training_args_kwargs["report_to"] = report_to
|
||||
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
output_dir=self.cfg.output_dir,
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
|
||||
@@ -66,10 +66,7 @@ class EvalFirstStepCallback(
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if (
|
||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and state.global_step == 1
|
||||
):
|
||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||
control.should_evaluate = True
|
||||
return control
|
||||
|
||||
|
||||
104
src/setuptools_axolotl_dynamic_dependencies.py
Normal file
104
src/setuptools_axolotl_dynamic_dependencies.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
dynamic requirements for axolotl
|
||||
"""
|
||||
import platform
|
||||
import re
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from setuptools.command.build_py import build_py as _build_py
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def parse_requirements():
    """Build the install requirements from ``./requirements.txt``.

    Lines starting with ``--extra-index-url`` contribute their URL to the
    dependency links. Blank lines, comment lines and lines mentioning one of
    the optional extras (flash-attn, flash-attention, deepspeed, mamba-ssm,
    lion-pytorch) are skipped. Every other line becomes an install
    requirement.

    On macOS the xformers requirement is dropped. On other platforms the
    installed torch version (``2.5.1`` when torch is not installed) is pinned
    so dependencies don't clobber it, and the xformers / torchao / autoawq
    requirements are adjusted to match that torch version. A requirement that
    is not listed in ``requirements.txt`` is simply left out; it is never
    introduced here.

    Returns:
        A ``(install_requires, dependency_links)`` tuple of lists of strings.

    Raises:
        ValueError: if the torch version string does not start with
            ``<major>.<minor>``.
    """
    extras_markers = (
        "flash-attn",
        "flash-attention",
        "deepspeed",
        "mamba-ssm",
        "lion-pytorch",
    )
    _install_requires = []
    _dependency_links = []

    with open("./requirements.txt", encoding="utf-8") as requirements_file:
        lines = [r.strip() for r in requirements_file]

    for line in lines:
        if line.startswith("--extra-index-url"):
            # Handle custom index URLs
            _, url = line.split()
            _dependency_links.append(url)
        elif (
            line
            and line[0] != "#"
            and not any(marker in line for marker in extras_markers)
        ):
            # Handle standard packages
            _install_requires.append(line)

    def _find(name):
        # First requirement line mentioning `name`, or None when absent.
        return next((req for req in _install_requires if name in req), None)

    def _drop(req):
        # Remove a previously found requirement; absent entries are ignored.
        if req is not None and req in _install_requires:
            _install_requires.remove(req)

    xformers_req = _find("xformers")
    torchao_req = _find("torchao")
    autoawq_req = _find("autoawq")

    if "Darwin" in platform.system():
        # don't install xformers on MacOS
        _drop(xformers_req)
        return _install_requires, _dependency_links

    # detect the version of torch already installed
    # and set it so dependencies don't clobber the torch version
    try:
        torch_version = version("torch")
    except PackageNotFoundError:
        torch_version = "2.5.1"
    _install_requires.append(f"torch=={torch_version}")

    version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
    if not version_match:
        raise ValueError("Invalid version format")
    major, minor, patch = version_match.groups()
    major, minor = int(major), int(minor)
    patch = int(patch) if patch is not None else 0  # Default patch to 0

    if (major, minor) >= (2, 5):
        xformers_pin = (
            "xformers==0.0.28.post2" if patch == 0 else "xformers==0.0.28.post3"
        )
        _drop(autoawq_req)
    elif (major, minor) >= (2, 4):
        xformers_pin = "xformers>=0.0.27" if patch == 0 else "xformers==0.0.28.post1"
    elif (major, minor) >= (2, 3):
        _drop(torchao_req)
        xformers_pin = "xformers>=0.0.26.post1" if patch == 0 else "xformers>=0.0.27"
    elif (major, minor) >= (2, 2):
        _drop(torchao_req)
        xformers_pin = "xformers>=0.0.25.post1"
    else:
        _drop(torchao_req)
        xformers_pin = "xformers>=0.0.23.post1"

    # Swap the generic xformers requirement for the torch-specific pin, but
    # only when xformers was requested in the first place.
    if xformers_req is not None:
        _drop(xformers_req)
        _install_requires.append(xformers_pin)

    return _install_requires, _dependency_links
|
||||
|
||||
|
||||
class BuildPyCommand(_build_py):
    """
    build_py command that fills in the distribution's install requirements
    from parse_requirements() when its options are finalized
    """

    def finalize_options(self):
        # Let the stock build_py command finalize its options first.
        super().finalize_options()
        # Only the requirement list is used; the dependency links returned
        # alongside it are discarded.
        requirements = parse_requirements()[0]
        self.distribution.install_requires = requirements
|
||||
Reference in New Issue
Block a user