Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
76bb09784d fix import 2025-03-05 14:05:27 -05:00
Wing Lian
0542c7dd56 add muon optimizer
optimizer_cls_and_kwargs is on trainer_kwargs
only add adamw_kwargs if they're non-null
fix mocks
better handling of override and check the optimizer
unwrap optimizer
2025-03-05 10:47:22 -05:00
32 changed files with 63 additions and 1000 deletions

View File

@@ -88,11 +88,6 @@ jobs:
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -80,11 +80,6 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

View File

@@ -154,6 +154,8 @@ datasets:
content: value
# ...
message_property_mappings:
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
user: ["human", "user"]
@@ -554,13 +556,6 @@ special_tokens:
# Add extra tokens.
tokens:
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
# Can be checked if they exist in tokenizer.json added_tokens.
added_tokens_overrides: # Dict[int, str]
# 128041: "<|im_start|>"
# 128042: "<|im_end|>"
# FSDP
fsdp:
fsdp_config:

View File

@@ -74,10 +74,6 @@ datasets:
train_on_eos:
```
::: {.callout-tip}
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml

View File

@@ -52,7 +52,3 @@ description: Frequently asked questions
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.

View File

@@ -28,17 +28,6 @@ val_set_size: 0.1
eval_steps: 100
```
Bradley-Terry chat templates expect single-turn conversations in the following format:
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### Process Reward Models (PRM)
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
@@ -56,5 +45,3 @@ datasets:
val_set_size: 0.1
eval_steps: 100
```
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.

View File

@@ -3,7 +3,6 @@ title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
@@ -529,7 +528,6 @@ trl:
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
@@ -538,8 +536,6 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### Using local dataset files
```yaml

View File

@@ -62,5 +62,5 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-lgpl==0.0.3
axolotl-contribs-mit==0.0.3

View File

@@ -24,5 +24,5 @@ if cce_spec:
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
)

View File

@@ -113,7 +113,7 @@ class ModalCloud(Cloud):
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
]
)
@@ -270,7 +270,6 @@ def _preprocess(config_yaml: str, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
@@ -289,7 +288,6 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _lm_eval(config_yaml: str, volumes=None):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"

View File

@@ -1,7 +1,6 @@
"""CLI to run training on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -35,8 +34,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
print_axolotl_text_art()
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(32, process_or_cpu_count)
num_proc = min(64, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,

View File

@@ -40,6 +40,7 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlMambaTrainer,
AxolotlORPOTrainer,
AxolotlPRMTrainer,
@@ -50,7 +51,6 @@ from axolotl.core.trainers.base import (
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.kto import AxolotlKTOTrainer
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,

View File

@@ -20,10 +20,9 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sequential
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from axolotl.core.trainers.kto import AxolotlKTOTrainer
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -875,6 +874,14 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers

View File

@@ -1,7 +0,0 @@
"""
KTO package initialization.
"""
from axolotl.core.trainers.kto.trainer import AxolotlKTOTrainer
__all__ = ["AxolotlKTOTrainer"]

View File

@@ -1,512 +0,0 @@
"""
KTO trainer implementation for Axolotl.
"""
from __future__ import annotations
import inspect
import logging
import warnings
from contextlib import nullcontext
from typing import Any, Callable, Literal, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
AutoModelForCausalLM,
BaseImageProcessor,
DataCollator,
FeatureExtractionMixin,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainerCallback,
TrainingArguments,
)
from transformers.trainer_utils import EvalLoopOutput
from trl import KTOTrainer
from trl.trainer.kto_config import KTOConfig
from trl.trainer.utils import KTODataCollatorWithPadding, pad_to_length
from axolotl.core.trainers.base import SchedulerMixin
# Check if PEFT is available
try:
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training, peft_module_casting_to_bf16
is_peft_available = True
except ImportError:
is_peft_available = False
LOG = logging.getLogger("axolotl.core.trainers.kto")
class AxolotlKTOTrainer(SchedulerMixin, Trainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str] = None,
args: KTOConfig = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
data_collator: Optional[DataCollator] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
dataset_tags=None,
model_adapter_name: Optional[str] = None,
ref_adapter_name: Optional[str] = None,
):
self.dataset_tags = dataset_tags
self._tag_names = ["trl", "kto"]
if hasattr(self, "tag_names"):
self._tag_names.extend(self.tag_names)
if type(args) is TrainingArguments:
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
if args.model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
else:
model_init_kwargs = args.model_init_kwargs
torch_dtype = model_init_kwargs.get("torch_dtype")
if torch_dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(torch_dtype, str) and torch_dtype != "auto":
torch_dtype = getattr(torch, torch_dtype)
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
)
model_init_kwargs["torch_dtype"] = torch_dtype
if args.ref_model_init_kwargs is None:
ref_model_init_kwargs = {}
elif not isinstance(ref_model, str):
raise ValueError(
"You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
)
else:
ref_model_init_kwargs = args.ref_model_init_kwargs
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
if torch_dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(torch_dtype, str) and torch_dtype != "auto":
torch_dtype = getattr(torch, torch_dtype)
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
)
ref_model_init_kwargs["torch_dtype"] = torch_dtype
if isinstance(model, str):
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
if isinstance(ref_model, str):
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
# has been called in order to properly call autocast if needed.
self._peft_has_been_casted_to_bf16 = False
if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
# if model is a peft model and we have a peft_config, we merge and unload it first
if isinstance(model, PeftModel):
model = model.merge_and_unload()
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
_support_gc_kwargs = hasattr(
args, "gradient_checkpointing_kwargs"
) and "gradient_checkpointing_kwargs" in list(
inspect.signature(prepare_model_for_kbit_training).parameters
)
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
if _support_gc_kwargs:
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# get peft model with the given config
model = get_peft_model(model, peft_config)
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(model)
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
self._peft_has_been_casted_to_bf16 = True
# For models that use gradient_checkpointing, we need to attach a hook that enables input
# to explicitly have `requires_grad=True`, otherwise training will either silently
# fail or completely fail.
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
raise ValueError(
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
" Please install `wandb` or `comet-ml` to resolve."
)
if model is not None:
self.is_encoder_decoder = model.config.is_encoder_decoder
elif args.is_encoder_decoder is None:
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
else:
self.is_encoder_decoder = args.is_encoder_decoder
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
self.model_adapter_name = model_adapter_name
self.ref_adapter_name = ref_adapter_name
if ref_model:
self.ref_model = ref_model
elif self.is_peft_model or args.precompute_ref_log_probs:
# The `model` with adapters turned off will be used as the reference model
self.ref_model = None
else:
self.ref_model = create_reference_model(model)
if processing_class is None:
raise ValueError(
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
)
if args.max_length is None:
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
" it will be set to `512` by default, but you should do it yourself in the future.",
UserWarning,
)
max_length = 512
if args.max_length is not None:
max_length = args.max_length
if args.max_prompt_length is None:
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
" it will be set to `128` by default, but you should do it yourself in the future.",
UserWarning,
)
max_prompt_length = 128
if args.max_prompt_length is not None:
max_prompt_length = args.max_prompt_length
max_completion_length = None
if args.max_completion_length is None and self.is_encoder_decoder:
warnings.warn(
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
" it will be set to `128` by default, but you should do it yourself in the future.",
UserWarning,
)
max_completion_length = 128
if args.max_completion_length is not None and self.is_encoder_decoder:
max_completion_length = args.max_completion_length
if data_collator is None:
data_collator = DPODataCollatorWithPadding(
pad_token_id=processing_class.pad_token_id,
label_pad_token_id=args.label_pad_token_id,
is_encoder_decoder=self.is_encoder_decoder,
)
if args.remove_unused_columns:
args.remove_unused_columns = False
# warn users
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
" we have set it for you, but you should do it yourself in the future.",
UserWarning,
)
self.use_dpo_data_collator = True
else:
self.use_dpo_data_collator = False
# Disable dropout in the model and reference model
if args.disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)
self.loss_type = args.loss_type
self.max_length = max_length
self.generate_during_eval = args.generate_during_eval
self.label_pad_token_id = args.label_pad_token_id
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
self.max_prompt_length = max_prompt_length
self.truncation_mode = args.truncation_mode
self.max_completion_length = max_completion_length
self.processing_class = processing_class
self.precompute_ref_log_probs = args.precompute_ref_log_probs
# Not all losses require a KL calculation
self.calculate_KL = True
if self.loss_type in ["apo_zero_unpaired"]:
self.calculate_KL = False
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
# keep track of first called to avoid computation of future calls
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False
# metric
self._stored_metrics = defaultdict(lambda: defaultdict(list))
# KTO parameter
self.beta = args.beta
self.desirable_weight = args.desirable_weight
self.undesirable_weight = args.undesirable_weight
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
warnings.warn(
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
"loss.",
UserWarning,
)
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
# issued.
model.warnings_issued["estimate_tokens"] = True
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# Extract the prompt if needed
train_dataset = train_dataset.map(
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
)
# Unpair the dataset if needed
train_dataset = maybe_unpair_preference_dataset(
train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
)
# Apply the chat template if needed
train_dataset = train_dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class},
num_proc=args.dataset_num_proc,
desc="Applying chat template to train dataset",
)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
)
eval_dataset = maybe_unpair_preference_dataset(
eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
)
eval_dataset = eval_dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class},
num_proc=args.dataset_num_proc,
desc="Applying chat template to eval dataset",
)
# Tokenize and prepare the training datasets
train_dataset = train_dataset.map(
_tokenize,
batched=True,
fn_kwargs={"tokenizer": self.processing_class},
num_proc=args.dataset_num_proc,
desc="Tokenizing train dataset",
)
fn_kwargs = {
"prefix": "",
"is_encoder_decoder": self.is_encoder_decoder,
"tokenizer": self.processing_class,
"max_length": self.max_length,
"truncation_mode": self.truncation_mode,
"label_pad_token_id": self.label_pad_token_id,
"max_prompt_length": self.max_prompt_length,
"max_completion_length": self.max_completion_length,
}
train_dataset = train_dataset.map(
_process_tokens,
fn_kwargs=fn_kwargs,
num_proc=args.dataset_num_proc,
desc="Processing tokenized train dataset",
)
# Tokenize and prepare the eval datasets
if eval_dataset is not None:
eval_dataset = eval_dataset.map(
_tokenize,
fn_kwargs={"tokenizer": self.processing_class},
batched=True,
num_proc=args.dataset_num_proc,
desc="Tokenizing eval dataset",
)
eval_dataset = eval_dataset.map(
_process_tokens,
fn_kwargs=fn_kwargs,
num_proc=args.dataset_num_proc,
desc="Processing tokenized eval dataset",
)
# Get KL datasets if needed
if self.calculate_KL:
if args.per_device_train_batch_size <= 1:
raise ValueError(
"Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
)
# create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
# i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
train_kl_dataset = train_dataset.map(
_get_kl_dataset,
batched=True,
batch_size=args.per_device_train_batch_size,
num_proc=args.dataset_num_proc,
desc="Extracting KL train dataset",
)
fn_kwargs["prefix"] = "KL_"
train_kl_dataset = train_kl_dataset.map(
_process_tokens,
fn_kwargs=fn_kwargs,
num_proc=args.dataset_num_proc,
remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
desc="Processing tokenized train KL dataset",
)
# merge the datasets
train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
if eval_dataset is not None:
# Get KL dataset
eval_kl_dataset = eval_dataset.map(
_get_kl_dataset,
batched=True,
batch_size=args.per_device_train_batch_size,
num_proc=args.dataset_num_proc,
desc="Extracting eval KL dataset",
)
eval_kl_dataset = eval_kl_dataset.map(
_process_tokens,
fn_kwargs=fn_kwargs,
num_proc=args.dataset_num_proc,
remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
desc="Processing tokenized eval KL dataset",
)
# merge the datasets
eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
# calculate dataset desirability balance
num_desirable = max(sum(train_dataset["label"]), 1)
num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
if num_desirable != num_undesirable:
# The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
if not (des_weight_in_range or und_weight_in_range):
warnings.warn(
"You have different amounts of desirable/positive and undesirable/negative examples but the "
"weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
f"on your data, we recommend EITHER "
f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
"See the documentation on how to optimally set these weights.",
UserWarning,
)
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
# self.model_accepts_loss_kwargs to False to enable scaling.
self.model_accepts_loss_kwargs = False
# Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)
if not hasattr(self, "accelerator"):
raise AttributeError(
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)
# Deepspeed Zero-3 does not support precompute_ref_log_probs
if self.is_deepspeed_enabled:
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
raise ValueError(
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
)
if self.ref_model is None:
if not (self.is_peft_model or self.precompute_ref_log_probs):
raise ValueError(
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
)
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

View File

@@ -17,7 +17,7 @@ Run the following command to install `cut_cross_entropy[transformers]` if you do
python scripts/cutcrossentropy_install.py | sh
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
```
## Usage

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
)

View File

@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator
from pydantic import BaseModel
class SpectrumArgs(BaseModel):
@@ -27,20 +27,3 @@ class SpectrumArgs(BaseModel):
spectrum_top_fraction: Optional[float] = 0.5
spectrum_model_name: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_fsdp_use_orig_params(cls, data):
if (
data.get("fsdp")
and data.get("fsdp_config")
and not data["fsdp_config"].get("use_orig_params")
and data.get("plugins")
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
):
# would otherwise raise
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
raise ValueError(
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
)
return data

View File

@@ -7,7 +7,7 @@ import signal
import sys
import weakref
from pathlib import Path
from typing import Any, Dict
from typing import Any
import torch
import transformers.modelcard
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -382,23 +382,21 @@ def handle_untrained_tokens_fix(
if not cfg.fix_untrained_tokens:
return
is_ds_zero3: bool = False
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
is_ds_zero3 = True
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
fix_kwargs: Dict[str, Any] = {}
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
if "is_ds_zero3" in sig.parameters:
fix_kwargs["is_ds_zero3"] = is_ds_zero3
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(

View File

@@ -72,6 +72,7 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
lion_pytorch = "lion_pytorch" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
@@ -728,7 +729,7 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_exact_deduplication: Optional[bool] = None
dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None
@@ -779,9 +780,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[
Union[Literal["unsloth", "offload"], bool]
] = Field(default=False)
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
default=False
)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
@@ -856,7 +857,6 @@ class AxolotlInputConfig(
special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None
added_tokens_overrides: Optional[Dict[int, str]] = None
torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None
@@ -1155,15 +1155,6 @@ class AxolotlInputConfig(
raise ValueError("gradient_checkpointing is not supported for MPT models")
return self
@model_validator(mode="after")
def check_offload_grad_checkpointing(self):
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
LOG.warning(
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
)
self.gradient_checkpointing = "offload"
return self
@model_validator(mode="after")
def check_better_transformers(self):
if self.flash_optimum is True:

View File

@@ -1,8 +1,7 @@
"""
GRPO specific configuration args
"""
from typing import Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -12,10 +11,7 @@ class TRLConfig(BaseModel):
Input args for TRL.
"""
beta: Optional[float] = Field(
default=None,
json_schema_extra={"description": "Beta for RL training"},
)
beta: Optional[float] = None
max_completion_length: Optional[int] = Field(
default=None,
json_schema_extra={
@@ -24,68 +20,17 @@ class TRLConfig(BaseModel):
)
# GRPO specific args
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
use_vllm: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
)
vllm_device: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
vllm_gpu_memory_utilization: Optional[float] = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
vllm_dtype: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
use_vllm: Optional[bool] = False
vllm_device: Optional[str] = "auto"
vllm_gpu_memory_utilization: Optional[float] = 0.9
vllm_max_model_len: Optional[int] = None
vllm_dtype: Optional[str] = "auto"
reward_funcs: Optional[list[str]] = Field(
default=None,
json_schema_extra={"description": "List of reward functions to load"},
)
reward_weights: Optional[list[float]] = Field(
default=None,
json_schema_extra={
"description": "Weights for each reward function. Must match the number of reward functions."
},
)
num_generations: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
},
)
log_completions: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
sync_ref_model: Optional[bool] = Field(
default=False,
json_schema_extra={
"description": (
"Whether to sync the reference model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
)
},
)
ref_model_mixup_alpha: Optional[float] = Field(
default=0.9,
json_schema_extra={
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
},
)
ref_model_sync_steps: Optional[int] = Field(
default=64,
json_schema_extra={
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
)
reward_funcs: Optional[List[str]] = None
reward_weights: Optional[List[float]] = None
num_generations: Optional[int] = None
log_completions: Optional[bool] = False
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64

View File

@@ -121,7 +121,6 @@ def drop_long_rl_seq(
def load_prepare_preference_datasets(cfg):
import pdb; pdb.set_trace()
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token

View File

@@ -79,7 +79,7 @@ def is_main_process():
def is_local_main_process():
return PartialState().is_local_main_process
return PartialState().is_main_process
def get_world_size():

View File

@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
)
def hf_grad_checkpoint_offload_wrapper(
def hf_grad_checkpoint_unsloth_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(

View File

@@ -57,14 +57,8 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
barrier,
get_device_count,
get_device_type,
is_local_main_process,
zero_only,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -171,95 +165,7 @@ def load_model_config(cfg):
return model_config
def modify_tokenizer_files(
tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
) -> str:
"""
Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
Args:
tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer
Returns:
Path to the modified tokenizer directory
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
"""
import json
# Create the tokenizer directory in output_dir if it doesn't exist
tokenizer_dir = os.path.join(output_dir, "tokenizer")
os.makedirs(tokenizer_dir, exist_ok=True)
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
# Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
# Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir)
# Get the token IDs and map them to their new values
token_id_mappings = {
int(token_id): new_value for token_id, new_value in token_mappings.items()
}
# 1. Update tokenizer_config.json - added_tokens_decoder
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
if os.path.exists(config_path):
with open(config_path, "r", encoding="utf-8") as f:
config_data = json.load(f)
# Update added_tokens_decoder
if "added_tokens_decoder" in config_data:
for token_id, new_value in token_id_mappings.items():
token_id_str = str(token_id)
if token_id_str in config_data["added_tokens_decoder"]:
config_data["added_tokens_decoder"][token_id_str][
"content"
] = new_value
else:
raise ValueError(
f"Token ID {token_id_str} not found in added_tokens_decoder"
)
# Write the updated config back
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
# 2. Update tokenizer.json - added_tokens
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
if os.path.exists(tokenizer_path):
with open(tokenizer_path, "r", encoding="utf-8") as f:
tokenizer_data = json.load(f)
# Update added_tokens
if "added_tokens" in tokenizer_data:
for token_id, new_value in token_id_mappings.items():
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
if token_entry["id"] == token_id:
tokenizer_data["added_tokens"][i]["content"] = new_value
break
else:
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
raise ValueError(
f"Token ID {token_id} not found in added_tokens"
)
# Write the updated tokenizer data back
with open(tokenizer_path, "w", encoding="utf-8") as f:
json.dump(tokenizer_data, f, indent=2)
barrier()
return tokenizer_dir
def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default
@@ -274,18 +180,8 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
# Set base tokenizer path
tokenizer_path = cfg.tokenizer_config
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
)
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_path,
cfg.tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
@@ -493,8 +389,8 @@ class ModelLoader:
patch_fa_peft_integration()
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.flash_attention:
self.patch_attention()

View File

@@ -14,7 +14,7 @@
h1 {
font-family: var(--font-title);
font-weight: 400;
font-size: 5rem;
font-size: 6rem;
line-height: 1.1;
letter-spacing: -0.05em;
font-feature-settings: "ss01" on;

View File

@@ -69,51 +69,6 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# pylint: disable=redefined-outer-name
def test_qwen2_w_cce(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"plugins": [
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
],
"cut_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
}
)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize(
"attention_type",
[

View File

@@ -750,66 +750,3 @@ class TestMultiGPULlama:
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
)

View File

@@ -66,54 +66,6 @@ class TestLlama:
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens_already_trained(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{

View File

@@ -1,7 +1,6 @@
"""
Test cases for the tokenizer loading
"""
import unittest
import pytest
@@ -10,7 +9,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
class TestTokenizers:
class TestTokenizers(unittest.TestCase):
"""
test class for the load_tokenizer fn
"""
@@ -76,48 +75,12 @@ class TestTokenizers:
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
assert len(tokenizer) == 32001
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
self.assertEqual(len(tokenizer), 32001)
# ensure reloading the tokenizer again from cfg results in same vocab length
tokenizer = load_tokenizer(cfg)
assert len(tokenizer) == 32001
def test_added_tokens_overrides(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {
128041: "RANDOM_OVERRIDE_1",
128042: "RANDOM_OVERRIDE_2",
},
"output_dir": temp_dir,
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
128041
]
assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
128042
]
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
"output_dir": temp_dir,
}
)
with pytest.raises(
ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
):
load_tokenizer(cfg)
self.assertEqual(len(tokenizer), 32001)
if __name__ == "__main__":