Compare commits
2 Commits
kto_fix
...
optimizers
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76bb09784d | ||
|
|
0542c7dd56 |
5
.github/workflows/main.yml
vendored
5
.github/workflows/main.yml
vendored
@@ -88,11 +88,6 @@ jobs:
|
|||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.6.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
5
.github/workflows/nightlies.yml
vendored
5
.github/workflows/nightlies.yml
vendored
@@ -80,11 +80,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.6.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
|
|||||||
|
|
||||||
RUN pip install jupyterlab notebook ipywidgets && \
|
RUN pip install jupyterlab notebook ipywidgets && \
|
||||||
jupyter lab clean
|
jupyter lab clean
|
||||||
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
||||||
mkdir -p ~/.ssh && \
|
mkdir -p ~/.ssh && \
|
||||||
chmod 700 ~/.ssh && \
|
chmod 700 ~/.ssh && \
|
||||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||||
|
|||||||
@@ -154,6 +154,8 @@ datasets:
|
|||||||
content: value
|
content: value
|
||||||
# ...
|
# ...
|
||||||
|
|
||||||
|
message_property_mappings:
|
||||||
|
|
||||||
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
||||||
roles:
|
roles:
|
||||||
user: ["human", "user"]
|
user: ["human", "user"]
|
||||||
@@ -554,13 +556,6 @@ special_tokens:
|
|||||||
# Add extra tokens.
|
# Add extra tokens.
|
||||||
tokens:
|
tokens:
|
||||||
|
|
||||||
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
|
||||||
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
|
|
||||||
# Can be checked if they exist in tokenizer.json added_tokens.
|
|
||||||
added_tokens_overrides: # Dict[int, str]
|
|
||||||
# 128041: "<|im_start|>"
|
|
||||||
# 128042: "<|im_end|>"
|
|
||||||
|
|
||||||
# FSDP
|
# FSDP
|
||||||
fsdp:
|
fsdp:
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
|
|||||||
@@ -74,10 +74,6 @@ datasets:
|
|||||||
train_on_eos:
|
train_on_eos:
|
||||||
```
|
```
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
|
|
||||||
:::
|
|
||||||
|
|
||||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -52,7 +52,3 @@ description: Frequently asked questions
|
|||||||
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
||||||
|
|
||||||
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
||||||
|
|
||||||
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
|
|
||||||
|
|
||||||
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
|
|
||||||
|
|||||||
@@ -28,17 +28,6 @@ val_set_size: 0.1
|
|||||||
eval_steps: 100
|
eval_steps: 100
|
||||||
```
|
```
|
||||||
|
|
||||||
Bradley-Terry chat templates expect single-turn conversations in the following format:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"system": "...", // optional
|
|
||||||
"input": "...",
|
|
||||||
"chosen": "...",
|
|
||||||
"rejected": "..."
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Process Reward Models (PRM)
|
### Process Reward Models (PRM)
|
||||||
|
|
||||||
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
||||||
@@ -56,5 +45,3 @@ datasets:
|
|||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
eval_steps: 100
|
eval_steps: 100
|
||||||
```
|
```
|
||||||
|
|
||||||
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ title: "RLHF (Beta)"
|
|||||||
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
||||||
back-to-top-navigation: true
|
back-to-top-navigation: true
|
||||||
toc: true
|
toc: true
|
||||||
toc-expand: 2
|
|
||||||
toc-depth: 4
|
toc-depth: 4
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -529,7 +528,6 @@ trl:
|
|||||||
vllm_gpu_memory_utilization: 0.15
|
vllm_gpu_memory_utilization: 0.15
|
||||||
num_generations: 4
|
num_generations: 4
|
||||||
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
||||||
reward_weights: [1.0]
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: openai/gsm8k
|
- path: openai/gsm8k
|
||||||
name: main
|
name: main
|
||||||
@@ -538,8 +536,6 @@ datasets:
|
|||||||
|
|
||||||
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
||||||
|
|
||||||
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
|
||||||
|
|
||||||
### Using local dataset files
|
### Using local dataset files
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -62,5 +62,5 @@ antlr4-python3-runtime==4.13.2
|
|||||||
torchao==0.7.0
|
torchao==0.7.0
|
||||||
schedulefree==1.3.0
|
schedulefree==1.3.0
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.3
|
||||||
axolotl-contribs-mit==0.0.3
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|||||||
@@ -24,5 +24,5 @@ if cce_spec:
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
|
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class ModalCloud(Cloud):
|
|||||||
[
|
[
|
||||||
# Random id for cache busting of branch commits
|
# Random id for cache busting of branch commits
|
||||||
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
||||||
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
|
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -270,7 +270,6 @@ def _preprocess(config_yaml: str, volumes=None):
|
|||||||
|
|
||||||
|
|
||||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
|
||||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||||
f_out.write(config_yaml)
|
f_out.write(config_yaml)
|
||||||
run_folder = "/workspace/mounts"
|
run_folder = "/workspace/mounts"
|
||||||
@@ -289,7 +288,6 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def _lm_eval(config_yaml: str, volumes=None):
|
def _lm_eval(config_yaml: str, volumes=None):
|
||||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
|
||||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||||
f_out.write(config_yaml)
|
f_out.write(config_yaml)
|
||||||
run_folder = "/workspace/mounts"
|
run_folder = "/workspace/mounts"
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""CLI to run training on a model."""
|
"""CLI to run training on a model."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -35,8 +34,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
check_user_token()
|
||||||
check_user_token()
|
|
||||||
|
|
||||||
if cfg.rl:
|
if cfg.rl:
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
|
|||||||
process_or_cpu_count: int = (
|
process_or_cpu_count: int = (
|
||||||
process_count or os.cpu_count() # type: ignore[assignment]
|
process_count or os.cpu_count() # type: ignore[assignment]
|
||||||
)
|
)
|
||||||
num_proc = min(32, process_or_cpu_count)
|
num_proc = min(64, process_or_cpu_count)
|
||||||
features = data.features.keys()
|
features = data.features.keys()
|
||||||
tokenized_data = data.map(
|
tokenized_data = data.map(
|
||||||
map_fn,
|
map_fn,
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
|
|||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.base import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
|
AxolotlKTOTrainer,
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
AxolotlORPOTrainer,
|
AxolotlORPOTrainer,
|
||||||
AxolotlPRMTrainer,
|
AxolotlPRMTrainer,
|
||||||
@@ -50,7 +51,6 @@ from axolotl.core.trainers.base import (
|
|||||||
from axolotl.core.trainers.dpo import DPOStrategy
|
from axolotl.core.trainers.dpo import DPOStrategy
|
||||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
from axolotl.core.trainers.kto import AxolotlKTOTrainer
|
|
||||||
from axolotl.core.training_args import (
|
from axolotl.core.training_args import (
|
||||||
AxolotlCPOConfig,
|
AxolotlCPOConfig,
|
||||||
AxolotlKTOConfig,
|
AxolotlKTOConfig,
|
||||||
|
|||||||
@@ -20,10 +20,9 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sequential
|
|||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import CPOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.core.trainers.kto import AxolotlKTOTrainer
|
|
||||||
from axolotl.integrations.base import BaseOptimizerFactory
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
@@ -875,6 +874,14 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
"""
|
|
||||||
KTO package initialization.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from axolotl.core.trainers.kto.trainer import AxolotlKTOTrainer
|
|
||||||
|
|
||||||
__all__ = ["AxolotlKTOTrainer"]
|
|
||||||
@@ -1,512 +0,0 @@
|
|||||||
"""
|
|
||||||
KTO trainer implementation for Axolotl.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from typing import Any, Callable, Literal, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from datasets import Dataset
|
|
||||||
from torch.utils.data import DataLoader, SequentialSampler
|
|
||||||
from transformers import (
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
BaseImageProcessor,
|
|
||||||
DataCollator,
|
|
||||||
FeatureExtractionMixin,
|
|
||||||
PreTrainedModel,
|
|
||||||
PreTrainedTokenizerBase,
|
|
||||||
ProcessorMixin,
|
|
||||||
Trainer,
|
|
||||||
TrainerCallback,
|
|
||||||
TrainingArguments,
|
|
||||||
)
|
|
||||||
from transformers.trainer_utils import EvalLoopOutput
|
|
||||||
from trl import KTOTrainer
|
|
||||||
from trl.trainer.kto_config import KTOConfig
|
|
||||||
from trl.trainer.utils import KTODataCollatorWithPadding, pad_to_length
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import SchedulerMixin
|
|
||||||
|
|
||||||
# Check if PEFT is available
|
|
||||||
try:
|
|
||||||
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training, peft_module_casting_to_bf16
|
|
||||||
is_peft_available = True
|
|
||||||
except ImportError:
|
|
||||||
is_peft_available = False
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.core.trainers.kto")
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, Trainer):
|
|
||||||
"""
|
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: Union[PreTrainedModel, nn.Module, str] = None,
|
|
||||||
args: KTOConfig = None,
|
|
||||||
train_dataset: Optional[Dataset] = None,
|
|
||||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
|
||||||
processing_class: Optional[
|
|
||||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
|
||||||
] = None,
|
|
||||||
data_collator: Optional[DataCollator] = None,
|
|
||||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
|
||||||
callbacks: Optional[list[TrainerCallback]] = None,
|
|
||||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
|
||||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
|
||||||
peft_config: Optional[dict] = None,
|
|
||||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
|
||||||
dataset_tags=None,
|
|
||||||
model_adapter_name: Optional[str] = None,
|
|
||||||
ref_adapter_name: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self.dataset_tags = dataset_tags
|
|
||||||
self._tag_names = ["trl", "kto"]
|
|
||||||
if hasattr(self, "tag_names"):
|
|
||||||
self._tag_names.extend(self.tag_names)
|
|
||||||
|
|
||||||
if type(args) is TrainingArguments:
|
|
||||||
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
|
|
||||||
|
|
||||||
if args.model_init_kwargs is None:
|
|
||||||
model_init_kwargs = {}
|
|
||||||
elif not isinstance(model, str):
|
|
||||||
raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
|
|
||||||
else:
|
|
||||||
model_init_kwargs = args.model_init_kwargs
|
|
||||||
torch_dtype = model_init_kwargs.get("torch_dtype")
|
|
||||||
if torch_dtype is not None:
|
|
||||||
# Convert to `torch.dtype` if an str is passed
|
|
||||||
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
|
||||||
torch_dtype = getattr(torch, torch_dtype)
|
|
||||||
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
|
||||||
)
|
|
||||||
model_init_kwargs["torch_dtype"] = torch_dtype
|
|
||||||
|
|
||||||
if args.ref_model_init_kwargs is None:
|
|
||||||
ref_model_init_kwargs = {}
|
|
||||||
elif not isinstance(ref_model, str):
|
|
||||||
raise ValueError(
|
|
||||||
"You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ref_model_init_kwargs = args.ref_model_init_kwargs
|
|
||||||
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
|
|
||||||
if torch_dtype is not None:
|
|
||||||
# Convert to `torch.dtype` if an str is passed
|
|
||||||
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
|
||||||
torch_dtype = getattr(torch, torch_dtype)
|
|
||||||
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
|
||||||
)
|
|
||||||
ref_model_init_kwargs["torch_dtype"] = torch_dtype
|
|
||||||
|
|
||||||
if isinstance(model, str):
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
|
||||||
|
|
||||||
if isinstance(ref_model, str):
|
|
||||||
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
|
|
||||||
|
|
||||||
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
|
||||||
# has been called in order to properly call autocast if needed.
|
|
||||||
self._peft_has_been_casted_to_bf16 = False
|
|
||||||
|
|
||||||
if not is_peft_available() and peft_config is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
|
|
||||||
)
|
|
||||||
elif is_peft_available() and peft_config is not None:
|
|
||||||
# if model is a peft model and we have a peft_config, we merge and unload it first
|
|
||||||
if isinstance(model, PeftModel):
|
|
||||||
model = model.merge_and_unload()
|
|
||||||
|
|
||||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
|
||||||
_support_gc_kwargs = hasattr(
|
|
||||||
args, "gradient_checkpointing_kwargs"
|
|
||||||
) and "gradient_checkpointing_kwargs" in list(
|
|
||||||
inspect.signature(prepare_model_for_kbit_training).parameters
|
|
||||||
)
|
|
||||||
|
|
||||||
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
|
||||||
|
|
||||||
if _support_gc_kwargs:
|
|
||||||
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
|
||||||
|
|
||||||
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
|
||||||
elif getattr(args, "gradient_checkpointing", False):
|
|
||||||
# For backward compatibility with older versions of transformers
|
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
|
||||||
model.enable_input_require_grads()
|
|
||||||
else:
|
|
||||||
|
|
||||||
def make_inputs_require_grad(module, input, output):
|
|
||||||
output.requires_grad_(True)
|
|
||||||
|
|
||||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
|
||||||
|
|
||||||
# get peft model with the given config
|
|
||||||
model = get_peft_model(model, peft_config)
|
|
||||||
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
|
||||||
peft_module_casting_to_bf16(model)
|
|
||||||
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
|
||||||
self._peft_has_been_casted_to_bf16 = True
|
|
||||||
|
|
||||||
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
|
||||||
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
|
||||||
# fail or completely fail.
|
|
||||||
elif getattr(args, "gradient_checkpointing", False):
|
|
||||||
# For backward compatibility with older versions of transformers
|
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
|
||||||
model.enable_input_require_grads()
|
|
||||||
else:
|
|
||||||
|
|
||||||
def make_inputs_require_grad(module, input, output):
|
|
||||||
output.requires_grad_(True)
|
|
||||||
|
|
||||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
|
||||||
|
|
||||||
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
|
||||||
raise ValueError(
|
|
||||||
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
|
||||||
" Please install `wandb` or `comet-ml` to resolve."
|
|
||||||
)
|
|
||||||
|
|
||||||
if model is not None:
|
|
||||||
self.is_encoder_decoder = model.config.is_encoder_decoder
|
|
||||||
elif args.is_encoder_decoder is None:
|
|
||||||
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
|
||||||
else:
|
|
||||||
self.is_encoder_decoder = args.is_encoder_decoder
|
|
||||||
|
|
||||||
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
|
||||||
self.model_adapter_name = model_adapter_name
|
|
||||||
self.ref_adapter_name = ref_adapter_name
|
|
||||||
|
|
||||||
if ref_model:
|
|
||||||
self.ref_model = ref_model
|
|
||||||
elif self.is_peft_model or args.precompute_ref_log_probs:
|
|
||||||
# The `model` with adapters turned off will be used as the reference model
|
|
||||||
self.ref_model = None
|
|
||||||
else:
|
|
||||||
self.ref_model = create_reference_model(model)
|
|
||||||
|
|
||||||
if processing_class is None:
|
|
||||||
raise ValueError(
|
|
||||||
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
|
|
||||||
)
|
|
||||||
if args.max_length is None:
|
|
||||||
warnings.warn(
|
|
||||||
"When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
|
|
||||||
" it will be set to `512` by default, but you should do it yourself in the future.",
|
|
||||||
UserWarning,
|
|
||||||
)
|
|
||||||
max_length = 512
|
|
||||||
if args.max_length is not None:
|
|
||||||
max_length = args.max_length
|
|
||||||
|
|
||||||
if args.max_prompt_length is None:
|
|
||||||
warnings.warn(
|
|
||||||
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
|
|
||||||
" it will be set to `128` by default, but you should do it yourself in the future.",
|
|
||||||
UserWarning,
|
|
||||||
)
|
|
||||||
max_prompt_length = 128
|
|
||||||
if args.max_prompt_length is not None:
|
|
||||||
max_prompt_length = args.max_prompt_length
|
|
||||||
|
|
||||||
max_completion_length = None
|
|
||||||
if args.max_completion_length is None and self.is_encoder_decoder:
|
|
||||||
warnings.warn(
|
|
||||||
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
|
|
||||||
" it will be set to `128` by default, but you should do it yourself in the future.",
|
|
||||||
UserWarning,
|
|
||||||
)
|
|
||||||
max_completion_length = 128
|
|
||||||
if args.max_completion_length is not None and self.is_encoder_decoder:
|
|
||||||
max_completion_length = args.max_completion_length
|
|
||||||
|
|
||||||
if data_collator is None:
|
|
||||||
data_collator = DPODataCollatorWithPadding(
|
|
||||||
pad_token_id=processing_class.pad_token_id,
|
|
||||||
label_pad_token_id=args.label_pad_token_id,
|
|
||||||
is_encoder_decoder=self.is_encoder_decoder,
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.remove_unused_columns:
|
|
||||||
args.remove_unused_columns = False
|
|
||||||
# warn users
|
|
||||||
warnings.warn(
|
|
||||||
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
|
|
||||||
" we have set it for you, but you should do it yourself in the future.",
|
|
||||||
UserWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.use_dpo_data_collator = True
|
|
||||||
else:
|
|
||||||
self.use_dpo_data_collator = False
|
|
||||||
|
|
||||||
# Disable dropout in the model and reference model
|
|
||||||
if args.disable_dropout:
|
|
||||||
disable_dropout_in_model(model)
|
|
||||||
if self.ref_model is not None:
|
|
||||||
disable_dropout_in_model(self.ref_model)
|
|
||||||
|
|
||||||
self.loss_type = args.loss_type
|
|
||||||
self.max_length = max_length
|
|
||||||
self.generate_during_eval = args.generate_during_eval
|
|
||||||
self.label_pad_token_id = args.label_pad_token_id
|
|
||||||
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
|
||||||
self.max_prompt_length = max_prompt_length
|
|
||||||
self.truncation_mode = args.truncation_mode
|
|
||||||
self.max_completion_length = max_completion_length
|
|
||||||
self.processing_class = processing_class
|
|
||||||
self.precompute_ref_log_probs = args.precompute_ref_log_probs
|
|
||||||
|
|
||||||
# Not all losses require a KL calculation
|
|
||||||
self.calculate_KL = True
|
|
||||||
if self.loss_type in ["apo_zero_unpaired"]:
|
|
||||||
self.calculate_KL = False
|
|
||||||
|
|
||||||
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
|
|
||||||
# keep track of first called to avoid computation of future calls
|
|
||||||
self._precomputed_train_ref_log_probs = False
|
|
||||||
self._precomputed_eval_ref_log_probs = False
|
|
||||||
|
|
||||||
# metric
|
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
|
||||||
|
|
||||||
# KTO parameter
|
|
||||||
self.beta = args.beta
|
|
||||||
self.desirable_weight = args.desirable_weight
|
|
||||||
self.undesirable_weight = args.undesirable_weight
|
|
||||||
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
|
||||||
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
|
||||||
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
|
||||||
warnings.warn(
|
|
||||||
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
|
||||||
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
|
||||||
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
|
||||||
"loss.",
|
|
||||||
UserWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
|
||||||
# input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
|
|
||||||
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
|
|
||||||
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
|
||||||
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
|
||||||
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
|
||||||
# issued.
|
|
||||||
model.warnings_issued["estimate_tokens"] = True
|
|
||||||
|
|
||||||
# Compute that only on the main process for faster data processing.
|
|
||||||
# see: https://github.com/huggingface/trl/pull/1255
|
|
||||||
with PartialState().local_main_process_first():
|
|
||||||
# Extract the prompt if needed
|
|
||||||
train_dataset = train_dataset.map(
|
|
||||||
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
|
|
||||||
)
|
|
||||||
# Unpair the dataset if needed
|
|
||||||
train_dataset = maybe_unpair_preference_dataset(
|
|
||||||
train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
|
|
||||||
)
|
|
||||||
# Apply the chat template if needed
|
|
||||||
train_dataset = train_dataset.map(
|
|
||||||
maybe_apply_chat_template,
|
|
||||||
fn_kwargs={"tokenizer": processing_class},
|
|
||||||
num_proc=args.dataset_num_proc,
|
|
||||||
desc="Applying chat template to train dataset",
|
|
||||||
)
|
|
||||||
if eval_dataset is not None:
|
|
||||||
eval_dataset = eval_dataset.map(
|
|
||||||
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
|
|
||||||
)
|
|
||||||
eval_dataset = maybe_unpair_preference_dataset(
|
|
||||||
eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
|
|
||||||
)
|
|
||||||
eval_dataset = eval_dataset.map(
|
|
||||||
maybe_apply_chat_template,
|
|
||||||
fn_kwargs={"tokenizer": processing_class},
|
|
||||||
num_proc=args.dataset_num_proc,
|
|
||||||
desc="Applying chat template to eval dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Tokenize and prepare the training datasets
|
|
||||||
train_dataset = train_dataset.map(
|
|
||||||
_tokenize,
|
|
||||||
batched=True,
|
|
||||||
fn_kwargs={"tokenizer": self.processing_class},
|
|
||||||
num_proc=args.dataset_num_proc,
|
|
||||||
desc="Tokenizing train dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
fn_kwargs = {
|
|
||||||
"prefix": "",
|
|
||||||
"is_encoder_decoder": self.is_encoder_decoder,
|
|
||||||
"tokenizer": self.processing_class,
|
|
||||||
"max_length": self.max_length,
|
|
||||||
"truncation_mode": self.truncation_mode,
|
|
||||||
"label_pad_token_id": self.label_pad_token_id,
|
|
||||||
"max_prompt_length": self.max_prompt_length,
|
|
||||||
"max_completion_length": self.max_completion_length,
|
|
||||||
}
|
|
||||||
|
|
||||||
train_dataset = train_dataset.map(
|
|
||||||
_process_tokens,
|
|
||||||
fn_kwargs=fn_kwargs,
|
|
||||||
num_proc=args.dataset_num_proc,
|
|
||||||
desc="Processing tokenized train dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Tokenize and prepare the eval datasets
|
|
||||||
if eval_dataset is not None:
|
|
||||||
eval_dataset = eval_dataset.map(
|
|
||||||
_tokenize,
|
|
||||||
fn_kwargs={"tokenizer": self.processing_class},
|
|
||||||
batched=True,
|
|
||||||
num_proc=args.dataset_num_proc,
|
|
||||||
desc="Tokenizing eval dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
eval_dataset = eval_dataset.map(
|
|
||||||
_process_tokens,
|
|
||||||
fn_kwargs=fn_kwargs,
|
|
||||||
num_proc=args.dataset_num_proc,
|
|
||||||
desc="Processing tokenized eval dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get KL datasets if needed
|
|
||||||
if self.calculate_KL:
|
|
||||||
if args.per_device_train_batch_size <= 1:
|
|
||||||
raise ValueError(
|
|
||||||
"Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
|
|
||||||
)
|
|
||||||
|
|
||||||
# create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
|
|
||||||
# i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
|
|
||||||
train_kl_dataset = train_dataset.map(
|
|
||||||
_get_kl_dataset,
|
|
||||||
batched=True,
|
|
||||||
batch_size=args.per_device_train_batch_size,
|
|
||||||
num_proc=args.dataset_num_proc,
|
|
||||||
desc="Extracting KL train dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
fn_kwargs["prefix"] = "KL_"
|
|
||||||
train_kl_dataset = train_kl_dataset.map(
|
|
||||||
_process_tokens,
|
|
||||||
fn_kwargs=fn_kwargs,
|
|
||||||
num_proc=args.dataset_num_proc,
|
|
||||||
remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
|
|
||||||
desc="Processing tokenized train KL dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
# merge the datasets
|
|
||||||
train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
|
|
||||||
|
|
||||||
if eval_dataset is not None:
|
|
||||||
# Get KL dataset
|
|
||||||
eval_kl_dataset = eval_dataset.map(
|
|
||||||
_get_kl_dataset,
|
|
||||||
batched=True,
|
|
||||||
batch_size=args.per_device_train_batch_size,
|
|
||||||
num_proc=args.dataset_num_proc,
|
|
||||||
desc="Extracting eval KL dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
eval_kl_dataset = eval_kl_dataset.map(
|
|
||||||
_process_tokens,
|
|
||||||
fn_kwargs=fn_kwargs,
|
|
||||||
num_proc=args.dataset_num_proc,
|
|
||||||
remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
|
|
||||||
desc="Processing tokenized eval KL dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
# merge the datasets
|
|
||||||
eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
|
|
||||||
|
|
||||||
# calculate dataset desirability balance
|
|
||||||
num_desirable = max(sum(train_dataset["label"]), 1)
|
|
||||||
num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
|
|
||||||
|
|
||||||
if num_desirable != num_undesirable:
|
|
||||||
# The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
|
|
||||||
des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
|
|
||||||
des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
|
|
||||||
und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
|
|
||||||
und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
|
|
||||||
|
|
||||||
des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
|
|
||||||
und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
|
|
||||||
|
|
||||||
if not (des_weight_in_range or und_weight_in_range):
|
|
||||||
warnings.warn(
|
|
||||||
"You have different amounts of desirable/positive and undesirable/negative examples but the "
|
|
||||||
"weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
|
|
||||||
f"on your data, we recommend EITHER "
|
|
||||||
f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
|
|
||||||
f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
|
|
||||||
"See the documentation on how to optimally set these weights.",
|
|
||||||
UserWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
model=model,
|
|
||||||
args=args,
|
|
||||||
data_collator=data_collator,
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
eval_dataset=eval_dataset,
|
|
||||||
processing_class=processing_class,
|
|
||||||
model_init=model_init,
|
|
||||||
compute_metrics=compute_metrics,
|
|
||||||
callbacks=callbacks,
|
|
||||||
optimizers=optimizers,
|
|
||||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
|
||||||
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
|
||||||
# self.model_accepts_loss_kwargs to False to enable scaling.
|
|
||||||
self.model_accepts_loss_kwargs = False
|
|
||||||
|
|
||||||
# Add tags for models that have been loaded with the correct transformers version
|
|
||||||
if hasattr(self.model, "add_model_tags"):
|
|
||||||
self.model.add_model_tags(self._tag_names)
|
|
||||||
|
|
||||||
if not hasattr(self, "accelerator"):
|
|
||||||
raise AttributeError(
|
|
||||||
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Deepspeed Zero-3 does not support precompute_ref_log_probs
|
|
||||||
if self.is_deepspeed_enabled:
|
|
||||||
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.ref_model is None:
|
|
||||||
if not (self.is_peft_model or self.precompute_ref_log_probs):
|
|
||||||
raise ValueError(
|
|
||||||
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if self.is_deepspeed_enabled:
|
|
||||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
|
||||||
else:
|
|
||||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
|
||||||
@@ -17,7 +17,7 @@ Run the following command to install `cut_cross_entropy[transformers]` if you do
|
|||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|
||||||
# if you are not in dev environment
|
# if you are not in dev environment
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install cut_cross_entropy with transformers support using "
|
"Please install cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
|
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
|
|||||||
"""
|
"""
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class SpectrumArgs(BaseModel):
|
class SpectrumArgs(BaseModel):
|
||||||
@@ -27,20 +27,3 @@ class SpectrumArgs(BaseModel):
|
|||||||
|
|
||||||
spectrum_top_fraction: Optional[float] = 0.5
|
spectrum_top_fraction: Optional[float] = 0.5
|
||||||
spectrum_model_name: Optional[str] = None
|
spectrum_model_name: Optional[str] = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_fsdp_use_orig_params(cls, data):
|
|
||||||
if (
|
|
||||||
data.get("fsdp")
|
|
||||||
and data.get("fsdp_config")
|
|
||||||
and not data["fsdp_config"].get("use_orig_params")
|
|
||||||
and data.get("plugins")
|
|
||||||
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
|
|
||||||
):
|
|
||||||
# would otherwise raise
|
|
||||||
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
|
|
||||||
raise ValueError(
|
|
||||||
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import weakref
|
import weakref
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
|||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
@@ -382,23 +382,21 @@ def handle_untrained_tokens_fix(
|
|||||||
if not cfg.fix_untrained_tokens:
|
if not cfg.fix_untrained_tokens:
|
||||||
return
|
return
|
||||||
|
|
||||||
is_ds_zero3: bool = False
|
|
||||||
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
|
|
||||||
is_ds_zero3 = True
|
|
||||||
|
|
||||||
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||||
sig = inspect.signature(fix_untrained_tokens)
|
sig = inspect.signature(fix_untrained_tokens)
|
||||||
|
|
||||||
fix_kwargs: Dict[str, Any] = {}
|
|
||||||
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||||
cfg.fix_untrained_tokens, list
|
cfg.fix_untrained_tokens, list
|
||||||
):
|
):
|
||||||
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
|
fix_untrained_tokens(
|
||||||
if "is_ds_zero3" in sig.parameters:
|
model,
|
||||||
fix_kwargs["is_ds_zero3"] = is_ds_zero3
|
tokenizer,
|
||||||
|
train_dataset,
|
||||||
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
token_ids_to_fix=cfg.fix_untrained_tokens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
model.save_pretrained(
|
model.save_pretrained(
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ class CustomSupportedOptimizers(str, Enum):
|
|||||||
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
||||||
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||||
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||||
|
lion_pytorch = "lion_pytorch" # pylint: disable=invalid-name
|
||||||
muon = "muon" # pylint: disable=invalid-name
|
muon = "muon" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
@@ -728,7 +729,7 @@ class AxolotlInputConfig(
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||||
)
|
)
|
||||||
dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
|
dataset_processes: Optional[int] = Field(default=os.cpu_count())
|
||||||
dataset_exact_deduplication: Optional[bool] = None
|
dataset_exact_deduplication: Optional[bool] = None
|
||||||
dataset_keep_in_memory: Optional[bool] = None
|
dataset_keep_in_memory: Optional[bool] = None
|
||||||
dataloader_pin_memory: Optional[bool] = None
|
dataloader_pin_memory: Optional[bool] = None
|
||||||
@@ -779,9 +780,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
# torch_dtype: Optional[torch.dtype]
|
# torch_dtype: Optional[torch.dtype]
|
||||||
|
|
||||||
gradient_checkpointing: Optional[
|
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
|
||||||
Union[Literal["unsloth", "offload"], bool]
|
default=False
|
||||||
] = Field(default=False)
|
)
|
||||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
unfrozen_parameters: Optional[List[str]] = None
|
unfrozen_parameters: Optional[List[str]] = None
|
||||||
@@ -856,7 +857,6 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
special_tokens: Optional[SpecialTokensConfig] = None
|
special_tokens: Optional[SpecialTokensConfig] = None
|
||||||
tokens: Optional[List[str]] = None
|
tokens: Optional[List[str]] = None
|
||||||
added_tokens_overrides: Optional[Dict[int, str]] = None
|
|
||||||
|
|
||||||
torch_compile: Optional[Union[Literal["auto"], bool]] = None
|
torch_compile: Optional[Union[Literal["auto"], bool]] = None
|
||||||
torch_compile_backend: Optional[str] = None
|
torch_compile_backend: Optional[str] = None
|
||||||
@@ -1155,15 +1155,6 @@ class AxolotlInputConfig(
|
|||||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def check_offload_grad_checkpointing(self):
|
|
||||||
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
|
|
||||||
LOG.warning(
|
|
||||||
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
|
|
||||||
)
|
|
||||||
self.gradient_checkpointing = "offload"
|
|
||||||
return self
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_better_transformers(self):
|
def check_better_transformers(self):
|
||||||
if self.flash_optimum is True:
|
if self.flash_optimum is True:
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
GRPO specific configuration args
|
GRPO specific configuration args
|
||||||
"""
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -12,10 +11,7 @@ class TRLConfig(BaseModel):
|
|||||||
Input args for TRL.
|
Input args for TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
beta: Optional[float] = Field(
|
beta: Optional[float] = None
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "Beta for RL training"},
|
|
||||||
)
|
|
||||||
max_completion_length: Optional[int] = Field(
|
max_completion_length: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -24,68 +20,17 @@ class TRLConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# GRPO specific args
|
# GRPO specific args
|
||||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
use_vllm: Optional[bool] = False
|
||||||
use_vllm: Optional[bool] = Field(
|
vllm_device: Optional[str] = "auto"
|
||||||
default=False,
|
vllm_gpu_memory_utilization: Optional[float] = 0.9
|
||||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
vllm_max_model_len: Optional[int] = None
|
||||||
)
|
vllm_dtype: Optional[str] = "auto"
|
||||||
vllm_device: Optional[str] = Field(
|
|
||||||
default="auto",
|
|
||||||
json_schema_extra={"description": "Device to use for VLLM"},
|
|
||||||
)
|
|
||||||
vllm_gpu_memory_utilization: Optional[float] = Field(
|
|
||||||
default=0.9,
|
|
||||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
|
||||||
)
|
|
||||||
vllm_dtype: Optional[str] = Field(
|
|
||||||
default="auto",
|
|
||||||
json_schema_extra={"description": "Data type for VLLM"},
|
|
||||||
)
|
|
||||||
vllm_max_model_len: Optional[int] = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Maximum length of the model context for VLLM"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
reward_funcs: Optional[list[str]] = Field(
|
reward_funcs: Optional[List[str]] = None
|
||||||
default=None,
|
reward_weights: Optional[List[float]] = None
|
||||||
json_schema_extra={"description": "List of reward functions to load"},
|
num_generations: Optional[int] = None
|
||||||
)
|
log_completions: Optional[bool] = False
|
||||||
reward_weights: Optional[list[float]] = Field(
|
|
||||||
default=None,
|
sync_ref_model: Optional[bool] = False
|
||||||
json_schema_extra={
|
ref_model_mixup_alpha: Optional[float] = 0.9
|
||||||
"description": "Weights for each reward function. Must match the number of reward functions."
|
ref_model_sync_steps: Optional[int] = 64
|
||||||
},
|
|
||||||
)
|
|
||||||
num_generations: Optional[int] = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
log_completions: Optional[bool] = Field(
|
|
||||||
default=False,
|
|
||||||
json_schema_extra={"description": "Whether to log completions"},
|
|
||||||
)
|
|
||||||
sync_ref_model: Optional[bool] = Field(
|
|
||||||
default=False,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": (
|
|
||||||
"Whether to sync the reference model every `ref_model_sync_steps` "
|
|
||||||
"steps, using the `ref_model_mixup_alpha` parameter."
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
ref_model_mixup_alpha: Optional[float] = Field(
|
|
||||||
default=0.9,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
ref_model_sync_steps: Optional[int] = Field(
|
|
||||||
default=64,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -121,7 +121,6 @@ def drop_long_rl_seq(
|
|||||||
|
|
||||||
|
|
||||||
def load_prepare_preference_datasets(cfg):
|
def load_prepare_preference_datasets(cfg):
|
||||||
import pdb; pdb.set_trace()
|
|
||||||
def load_split(dataset_cfgs, _cfg):
|
def load_split(dataset_cfgs, _cfg):
|
||||||
split_datasets: List[Any] = []
|
split_datasets: List[Any] = []
|
||||||
use_auth_token = _cfg.hf_use_auth_token
|
use_auth_token = _cfg.hf_use_auth_token
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ def is_main_process():
|
|||||||
|
|
||||||
|
|
||||||
def is_local_main_process():
|
def is_local_main_process():
|
||||||
return PartialState().is_local_main_process
|
return PartialState().is_main_process
|
||||||
|
|
||||||
|
|
||||||
def get_world_size():
|
def get_world_size():
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def hf_grad_checkpoint_offload_wrapper(
|
def hf_grad_checkpoint_unsloth_wrapper(
|
||||||
decoder_layer, *args, use_reentrant=None
|
decoder_layer, *args, use_reentrant=None
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||||
|
|||||||
@@ -57,14 +57,8 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
|||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
|
||||||
barrier,
|
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
|
||||||
get_device_count,
|
|
||||||
get_device_type,
|
|
||||||
is_local_main_process,
|
|
||||||
zero_only,
|
|
||||||
)
|
|
||||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
@@ -171,95 +165,7 @@ def load_model_config(cfg):
|
|||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
|
||||||
def modify_tokenizer_files(
|
|
||||||
tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
|
|
||||||
|
|
||||||
This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tokenizer_path: Path or name of the original tokenizer
|
|
||||||
token_mappings: Dict mapping {token_id (int): new_token_string}
|
|
||||||
output_dir: Directory to save the modified tokenizer
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path to the modified tokenizer directory
|
|
||||||
|
|
||||||
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
# Create the tokenizer directory in output_dir if it doesn't exist
|
|
||||||
tokenizer_dir = os.path.join(output_dir, "tokenizer")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
|
|
||||||
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
|
|
||||||
# Load the tokenizer
|
|
||||||
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
|
|
||||||
|
|
||||||
# Save the tokenizer to the output directory
|
|
||||||
temp_tokenizer.save_pretrained(tokenizer_dir)
|
|
||||||
|
|
||||||
# Get the token IDs and map them to their new values
|
|
||||||
token_id_mappings = {
|
|
||||||
int(token_id): new_value for token_id, new_value in token_mappings.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# 1. Update tokenizer_config.json - added_tokens_decoder
|
|
||||||
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
|
|
||||||
if os.path.exists(config_path):
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
|
||||||
config_data = json.load(f)
|
|
||||||
|
|
||||||
# Update added_tokens_decoder
|
|
||||||
if "added_tokens_decoder" in config_data:
|
|
||||||
for token_id, new_value in token_id_mappings.items():
|
|
||||||
token_id_str = str(token_id)
|
|
||||||
if token_id_str in config_data["added_tokens_decoder"]:
|
|
||||||
config_data["added_tokens_decoder"][token_id_str][
|
|
||||||
"content"
|
|
||||||
] = new_value
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Token ID {token_id_str} not found in added_tokens_decoder"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Write the updated config back
|
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(config_data, f, indent=2)
|
|
||||||
|
|
||||||
# 2. Update tokenizer.json - added_tokens
|
|
||||||
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
|
|
||||||
if os.path.exists(tokenizer_path):
|
|
||||||
with open(tokenizer_path, "r", encoding="utf-8") as f:
|
|
||||||
tokenizer_data = json.load(f)
|
|
||||||
|
|
||||||
# Update added_tokens
|
|
||||||
if "added_tokens" in tokenizer_data:
|
|
||||||
for token_id, new_value in token_id_mappings.items():
|
|
||||||
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
|
|
||||||
if token_entry["id"] == token_id:
|
|
||||||
tokenizer_data["added_tokens"][i]["content"] = new_value
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
|
|
||||||
raise ValueError(
|
|
||||||
f"Token ID {token_id} not found in added_tokens"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Write the updated tokenizer data back
|
|
||||||
with open(tokenizer_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(tokenizer_data, f, indent=2)
|
|
||||||
|
|
||||||
barrier()
|
|
||||||
return tokenizer_dir
|
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
"""Load and configure the tokenizer based on the provided config."""
|
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
use_fast = True # this is the default
|
use_fast = True # this is the default
|
||||||
@@ -274,18 +180,8 @@ def load_tokenizer(cfg):
|
|||||||
if cfg.tokenizer_type:
|
if cfg.tokenizer_type:
|
||||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||||
|
|
||||||
# Set base tokenizer path
|
|
||||||
tokenizer_path = cfg.tokenizer_config
|
|
||||||
|
|
||||||
# Apply token string overrides if specified
|
|
||||||
if cfg.added_tokens_overrides:
|
|
||||||
# Modify tokenizer files and get path to modified tokenizer
|
|
||||||
tokenizer_path = modify_tokenizer_files(
|
|
||||||
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
tokenizer_path,
|
cfg.tokenizer_config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
use_fast=use_fast,
|
use_fast=use_fast,
|
||||||
**tokenizer_kwargs,
|
**tokenizer_kwargs,
|
||||||
@@ -493,8 +389,8 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_fa_peft_integration()
|
patch_fa_peft_integration()
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
if self.cfg.gradient_checkpointing == "unsloth":
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
h1 {
|
h1 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 5rem;
|
font-size: 6rem;
|
||||||
line-height: 1.1;
|
line-height: 1.1;
|
||||||
letter-spacing: -0.05em;
|
letter-spacing: -0.05em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
|
|||||||
@@ -69,51 +69,6 @@ class TestCutCrossEntropyIntegration:
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
|
||||||
def test_qwen2_w_cce(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
|
||||||
"plugins": [
|
|
||||||
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
|
|
||||||
],
|
|
||||||
"cut_cross_entropy": True,
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"save_safetensors": True,
|
|
||||||
"max_steps": 10,
|
|
||||||
"bf16": "auto",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
prepare_plugins(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
major, minor, _ = get_pytorch_version()
|
|
||||||
if (major, minor) < (2, 4):
|
|
||||||
with pytest.raises(ImportError):
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
else:
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"attention_type",
|
"attention_type",
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -750,66 +750,3 @@ class TestMultiGPULlama:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_fix_untrained_tokens(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"fix_untrained_tokens": True,
|
|
||||||
"sequence_len": 512,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
"bos_token": "<|custom_im_start|>",
|
|
||||||
"eos_token": "<|custom_im_end|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"chat_template": "jinja",
|
|
||||||
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
|
|
||||||
"path": "mlabonne/FineTome-100k",
|
|
||||||
"type": "chat_template",
|
|
||||||
"split": "train[:10%]",
|
|
||||||
"field_messages": "conversations",
|
|
||||||
"message_field_role": "from",
|
|
||||||
"message_field_content": "value",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 5,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"sample_packing": True,
|
|
||||||
"bf16": True,
|
|
||||||
"save_safetensors": True,
|
|
||||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
|
|
||||||
"use_tensorboard": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# write cfg to yaml file
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"axolotl",
|
|
||||||
"train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
"--num-processes",
|
|
||||||
"2",
|
|
||||||
"--main-process-port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
check_tensorboard(
|
|
||||||
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -66,54 +66,6 @@ class TestLlama:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
def test_fix_untrained_tokens(self, temp_dir):
|
def test_fix_untrained_tokens(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"fix_untrained_tokens": True,
|
|
||||||
"sequence_len": 512,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
"bos_token": "<|custom_im_start|>",
|
|
||||||
"eos_token": "<|custom_im_end|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"chat_template": "jinja",
|
|
||||||
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
|
|
||||||
"path": "mlabonne/FineTome-100k",
|
|
||||||
"type": "chat_template",
|
|
||||||
"split": "train[:10%]",
|
|
||||||
"field_messages": "conversations",
|
|
||||||
"message_field_role": "from",
|
|
||||||
"message_field_content": "value",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 5,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"sample_packing": True,
|
|
||||||
"bf16": True,
|
|
||||||
"save_safetensors": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
|
|
||||||
def test_fix_untrained_tokens_already_trained(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Test cases for the tokenizer loading
|
Test cases for the tokenizer loading
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -10,7 +9,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
class TestTokenizers:
|
class TestTokenizers(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
test class for the load_tokenizer fn
|
test class for the load_tokenizer fn
|
||||||
"""
|
"""
|
||||||
@@ -76,48 +75,12 @@ class TestTokenizers:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
|
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
|
||||||
assert len(tokenizer) == 32001
|
self.assertEqual(len(tokenizer), 32001)
|
||||||
|
|
||||||
# ensure reloading the tokenizer again from cfg results in same vocab length
|
# ensure reloading the tokenizer again from cfg results in same vocab length
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert len(tokenizer) == 32001
|
self.assertEqual(len(tokenizer), 32001)
|
||||||
|
|
||||||
def test_added_tokens_overrides(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
# use with tokenizer that has reserved_tokens in added_tokens
|
|
||||||
"tokenizer_config": "NousResearch/Llama-3.2-1B",
|
|
||||||
"added_tokens_overrides": {
|
|
||||||
128041: "RANDOM_OVERRIDE_1",
|
|
||||||
128042: "RANDOM_OVERRIDE_2",
|
|
||||||
},
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
|
||||||
assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
|
|
||||||
128041
|
|
||||||
]
|
|
||||||
assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
|
|
||||||
128042
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
# use with tokenizer that has reserved_tokens in added_tokens
|
|
||||||
"tokenizer_config": "NousResearch/Llama-3.2-1B",
|
|
||||||
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
|
|
||||||
):
|
|
||||||
load_tokenizer(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user