Compare commits

..

17 Commits

Author SHA1 Message Date
Wing Lian
8fc4c420a4 Add kd coefficient scheduler 2025-03-18 09:01:58 -04:00
Wing Lian
4f5eb42a73 remove reference to deprecated import (#2407) 2025-03-15 08:49:41 -04:00
Wing Lian
fbe54be6b8 only validate hf user token on rank 0 (#2408) 2025-03-13 23:29:06 -04:00
Wing Lian
04f6324833 build cloud images with torch 2.6.0 (#2413)
* build cloud images with torch 2.6.0

* nightlies too
2025-03-13 23:28:51 -04:00
Wing Lian
f0072f3b9d use max of 32 dataset processes if not explicit (#2403)
* use max of 32 dataset processes if not explicit

* change alternate min val for consistency
2025-03-11 12:02:58 -04:00
Wing Lian
59899b9817 pass additional info for fix untrained tokens when using distributed + offloading (#2388)
* pass additional info for fix untrained tokens when using distributed + offloading

* use latest version of vendored lib

* use v0.0.5 of contribs lgpl

* fix for no bad tokens and add tests

* use release

* add multigpu test too

* make sure the multigpu zero3 test actually uses zero3
2025-03-11 12:02:43 -04:00
NanoCode012
4a736986fa fix(modal): add git pull when getting branch files (#2399) 2025-03-10 15:14:41 -04:00
Wing Lian
5d0f110a3b include iproute2 and nvtop in cloud image (#2393) 2025-03-10 15:13:38 -04:00
NanoCode012
83f8698b8a fix: create mount folder on modal if not exist (#2390) 2025-03-10 16:27:42 +07:00
xzuyn
60a11a6410 Use Latest Cut Cross Entropy (#2392)
* Update __init__.py

* Update README.md

* Update cutcrossentropy_install.py

* add test
2025-03-10 16:26:40 +07:00
NanoCode012
46a045e528 chore(doc): add faq when having no default chat_template (#2398)
* chore(doc): add faq when having no default chat_template

* Update docs/dataset-formats/conversation.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update docs/faq.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-10 16:25:50 +07:00
NanoCode012
3b477e08a0 feat(doc): add more info on RewardModel datasets (#2391)
* fix: reduce title size

* feat(doc): add rm dataset info

* Update docs/reward_modelling.qmd following suggestion

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-10 16:25:31 +07:00
NanoCode012
16dc6ee68d refactor: trl grpo configs to have descriptions (#2386)
* refactor: trl grpo configs to have descriptions

* chore: caps
2025-03-07 08:58:53 -05:00
Wing Lian
fa7c79b3b9 remove lion-pytorch as it's already handled upstream (#2389) 2025-03-07 08:58:15 -05:00
Wing Lian
ae66374156 Optimizer refactor and add Muon support (#2367)
* add muon optimizer

optimizer_cls_and_kwargs is on trainer_kwargs
only add adamw_kwargs if they're non-null
fix mocks
better handling of override and check the optimizer
unwrap optimizer

* fix import
2025-03-06 11:49:19 -05:00
Wing Lian
5e21b1a9da various fixes 20250305 (#2384)
* various validation fixes

* fix check for non-truthy value
2025-03-06 11:48:44 -05:00
mhenrichsen
575e5f28ec Update Tokenizer Overrides Handling in models.py (#1549)
* override special tokens mock code

* fix(doc): remove duplicate config

* feat: replace added_tokens in tokenizer and add test

* make sure to run tokenizer modification on rank 0 only

* use is local main process instead

* feat: rename config

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-03-05 11:15:12 -05:00
36 changed files with 590 additions and 81 deletions

View File

@@ -88,6 +88,11 @@ jobs:
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -80,6 +80,11 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux && \
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

View File

@@ -154,8 +154,6 @@ datasets:
content: value
# ...
message_property_mappings:
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
user: ["human", "user"]
@@ -556,6 +554,13 @@ special_tokens:
# Add extra tokens.
tokens:
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
# Can be checked if they exist in tokenizer.json added_tokens.
added_tokens_overrides: # Dict[int, str]
# 128041: "<|im_start|>"
# 128042: "<|im_end|>"
# FSDP
fsdp:
fsdp_config:

View File

@@ -74,6 +74,10 @@ datasets:
train_on_eos:
```
::: {.callout-tip}
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml

View File

@@ -52,3 +52,7 @@ description: Frequently asked questions
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.

View File

@@ -28,6 +28,17 @@ val_set_size: 0.1
eval_steps: 100
```
Bradley-Terry chat templates expect single-turn conversations in the following format:
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### Process Reward Models (PRM)
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
@@ -45,3 +56,5 @@ datasets:
val_set_size: 0.1
eval_steps: 100
```
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.

View File

@@ -3,6 +3,7 @@ title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
@@ -528,6 +529,7 @@ trl:
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
@@ -536,6 +538,8 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### Using local dataset files
```yaml

View File

@@ -62,5 +62,5 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.3
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3

View File

@@ -24,5 +24,5 @@ if cce_spec:
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
)

View File

@@ -113,7 +113,7 @@ class ModalCloud(Cloud):
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
]
)
@@ -270,6 +270,7 @@ def _preprocess(config_yaml: str, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
@@ -288,6 +289,7 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _lm_eval(config_yaml: str, volumes=None):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"

View File

@@ -1,6 +1,7 @@
"""CLI to run training on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -34,7 +35,8 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(64, process_or_cpu_count)
num_proc = min(32, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,

View File

@@ -751,8 +751,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_ce_alpha_end is not None:
training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end
if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_alpha_end is not None:
training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:

View File

@@ -17,7 +17,7 @@ Run the following command to install `cut_cross_entropy[transformers]` if you do
python scripts/cutcrossentropy_install.py | sh
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
```
## Usage

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
)

View File

@@ -34,3 +34,12 @@ class KDPlugin(BasePlugin):
return AxolotlKDTrainer
return None
def add_callbacks_post_trainer(self, cfg, trainer):
callbacks = []
if cfg.kd_trainer:
from .callbacks import KDAlphaSchedulerCallback
callbacks.append(KDAlphaSchedulerCallback())
return callbacks

View File

@@ -30,6 +30,8 @@ class KDArgs(BaseModel):
float
] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_ce_alpha_end: Optional[float] = None # end value for kd_ce_alpha
kd_alpha_end: Optional[float] = None # end value for kd_alpha
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[

View File

@@ -0,0 +1,28 @@
from transformers import TrainerCallback
class KDAlphaSchedulerCallback(TrainerCallback):
"""Callback to for scheduling KD alpha during training."""
def on_epoch_begin(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if int(state.epoch) == 0:
state.kd_alpha = args.kd_alpha
state.kd_ce_alpha = args.kd_ce_alpha
elif int(state.epoch) == state.num_train_epochs - 1:
if args.kd_alpha_end is not None:
control.kd_alpha = args.kd_alpha_end
if args.kd_ce_alpha_end is not None:
control.kd_ce_alpha = args.kd_ce_alpha_end
else:
epoch_steps = state.num_train_epochs - 1
scale = int(state.epoch) / epoch_steps
if args.kd_alpha_end is not None:
control.kd_alpha = (
args.kd_alpha + (args.kd_alpha_end - args.kd_alpha) * scale
)
if args.kd_ce_alpha_end is not None:
control.kd_ce_alpha = (
args.kd_ce_alpha + (args.kd_ce_alpha_end - args.kd_ce_alpha) * scale
)

View File

@@ -62,10 +62,16 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
Transform logprobs to target format for KD training
"""
logprobs = sample.pop(self.logprobs_field)
if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys():
logprobs = sample.pop("target_logprobs")
token_ids = sample.pop("target_token_ids")
else:
logprobs = sample.pop(self.logprobs_field)
token_ids = [None] * len(logprobs)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
target_padding_len = input_seq_len - target_seq_len
# get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [
len(logprobs[i])
@@ -82,11 +88,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_token_ids = []
target_mask = []
if input_padding_len < 0:
if target_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
target_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
@@ -98,33 +104,37 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
for _ in range(shift, target_padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for position in range(input_padding_len, input_seq_len):
for position in range(target_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs):
for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids):
# Initialize collections for logprobs and token_ids
position_logprobs = []
position_token_ids = []
# Process each token probability entry
for entry in token_pos_logprobs:
# Extract logprob value
logprob = entry["logprob"]
if token_pos_token_ids is None:
for entry in token_pos_logprobs:
# Extract logprob value
logprob = entry["logprob"]
# Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1])
# Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1])
# Append to our collections
position_logprobs.append(logprob)
position_token_ids.append(token_id)
# Append to our collections
position_logprobs.append(logprob)
position_token_ids.append(token_id)
else:
position_logprobs = token_pos_logprobs
position_token_ids = token_pos_token_ids
# Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor(
@@ -143,6 +153,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
teacher_probs_t2 = teacher_probs_t1**exponent
else:
teacher_probs_t2 = teacher_probs_t1
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True

View File

@@ -16,17 +16,35 @@
KD trainer
"""
from transformers import TrainerControl
from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainerControl(TrainerControl):
kd_alpha: float = 1.0
kd_ce_alpha: float = 0.0
def state(self) -> dict:
state_val = super().state()
state_val["args"]["kd_alpha"] = self.kd_alpha
state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha
class AxolotlKDTrainer(AxolotlTrainer):
"""
Custom trainer subclass for Knowledge Distillation (KD)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.kd_alpha = self.args.kd_alpha
self.kd_ce_alpha = self.args.kd_ce_alpha
self.control = AxolotlKDTrainerControl()
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
columns_to_add = []
@@ -95,9 +113,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
)
if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
if self.kd_ce_alpha > 0:
loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists

View File

@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
"""
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
class SpectrumArgs(BaseModel):
@@ -27,3 +27,20 @@ class SpectrumArgs(BaseModel):
spectrum_top_fraction: Optional[float] = 0.5
spectrum_model_name: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_fsdp_use_orig_params(cls, data):
if (
data.get("fsdp")
and data.get("fsdp_config")
and not data["fsdp_config"].get("use_orig_params")
and data.get("plugins")
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
):
# would otherwise raise
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
raise ValueError(
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
)
return data

View File

@@ -7,7 +7,7 @@ import signal
import sys
import weakref
from pathlib import Path
from typing import Any
from typing import Any, Dict
import torch
import transformers.modelcard
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -382,21 +382,23 @@ def handle_untrained_tokens_fix(
if not cfg.fix_untrained_tokens:
return
is_ds_zero3: bool = False
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
is_ds_zero3 = True
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
fix_kwargs: Dict[str, Any] = {}
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
if "is_ds_zero3" in sig.parameters:
fix_kwargs["is_ds_zero3"] = is_ds_zero3
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
if cfg.local_rank == 0:
model.save_pretrained(

View File

@@ -813,6 +813,15 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
# TODO if using deepspeed and it's a file, save deepspeed config too
if args.deepspeed and os.path.isfile(args.deepspeed):
LOG.info(f"DeepSpeed config has been saved to the WandB run.")
artifact = wandb.Artifact(
f"deepspeed-{wandb.run.id}", type="deepspeed-config"
)
artifact.add_file(args.deepspeed)
wandb.log_artifact(artifact)
wandb.save(args.deepspeed)
return control

View File

@@ -173,10 +173,16 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
]
out_features[i][feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
try:
arrays = [
np.array(item[feature])
for item in features_
if feature in item
]
if arrays[0].dtype != "object":
out_features[i][feature] = np.concatenate(arrays)
except ValueError:
pass
return super().__call__(out_features, return_tensors=return_tensors)

View File

@@ -72,7 +72,6 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
lion_pytorch = "lion_pytorch" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
@@ -729,7 +728,7 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
dataset_exact_deduplication: Optional[bool] = None
dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None
@@ -780,9 +779,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
default=False
)
gradient_checkpointing: Optional[
Union[Literal["unsloth", "offload"], bool]
] = Field(default=False)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
@@ -857,6 +856,7 @@ class AxolotlInputConfig(
special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None
added_tokens_overrides: Optional[Dict[int, str]] = None
torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None
@@ -1155,6 +1155,15 @@ class AxolotlInputConfig(
raise ValueError("gradient_checkpointing is not supported for MPT models")
return self
@model_validator(mode="after")
def check_offload_grad_checkpointing(self):
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
LOG.warning(
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
)
self.gradient_checkpointing = "offload"
return self
@model_validator(mode="after")
def check_better_transformers(self):
if self.flash_optimum is True:

View File

@@ -1,7 +1,8 @@
"""
GRPO specific configuration args
"""
from typing import List, Optional
from typing import Optional
from pydantic import BaseModel, Field
@@ -11,7 +12,10 @@ class TRLConfig(BaseModel):
Input args for TRL.
"""
beta: Optional[float] = None
beta: Optional[float] = Field(
default=None,
json_schema_extra={"description": "Beta for RL training"},
)
max_completion_length: Optional[int] = Field(
default=None,
json_schema_extra={
@@ -20,17 +24,68 @@ class TRLConfig(BaseModel):
)
# GRPO specific args
use_vllm: Optional[bool] = False
vllm_device: Optional[str] = "auto"
vllm_gpu_memory_utilization: Optional[float] = 0.9
vllm_max_model_len: Optional[int] = None
vllm_dtype: Optional[str] = "auto"
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
use_vllm: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
)
vllm_device: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
vllm_gpu_memory_utilization: Optional[float] = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
vllm_dtype: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
reward_funcs: Optional[List[str]] = None
reward_weights: Optional[List[float]] = None
num_generations: Optional[int] = None
log_completions: Optional[bool] = False
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64
reward_funcs: Optional[list[str]] = Field(
default=None,
json_schema_extra={"description": "List of reward functions to load"},
)
reward_weights: Optional[list[float]] = Field(
default=None,
json_schema_extra={
"description": "Weights for each reward function. Must match the number of reward functions."
},
)
num_generations: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
},
)
log_completions: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
sync_ref_model: Optional[bool] = Field(
default=False,
json_schema_extra={
"description": (
"Whether to sync the reference model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
)
},
)
ref_model_mixup_alpha: Optional[float] = Field(
default=0.9,
json_schema_extra={
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
},
)
ref_model_sync_steps: Optional[int] = Field(
default=64,
json_schema_extra={
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
)

View File

@@ -79,7 +79,7 @@ def is_main_process():
def is_local_main_process():
return PartialState().is_main_process
return PartialState().is_local_main_process
def get_world_size():

View File

@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
)
def hf_grad_checkpoint_unsloth_wrapper(
def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(

View File

@@ -24,7 +24,6 @@ from peft import (
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from torch import nn
from transformers import ( # noqa: F401
AddedToken,
@@ -57,8 +56,14 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.distributed import (
barrier,
get_device_count,
get_device_type,
is_local_main_process,
zero_only,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -165,7 +170,95 @@ def load_model_config(cfg):
return model_config
def modify_tokenizer_files(
tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
) -> str:
"""
Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
Args:
tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer
Returns:
Path to the modified tokenizer directory
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
"""
import json
# Create the tokenizer directory in output_dir if it doesn't exist
tokenizer_dir = os.path.join(output_dir, "tokenizer")
os.makedirs(tokenizer_dir, exist_ok=True)
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
# Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
# Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir)
# Get the token IDs and map them to their new values
token_id_mappings = {
int(token_id): new_value for token_id, new_value in token_mappings.items()
}
# 1. Update tokenizer_config.json - added_tokens_decoder
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
if os.path.exists(config_path):
with open(config_path, "r", encoding="utf-8") as f:
config_data = json.load(f)
# Update added_tokens_decoder
if "added_tokens_decoder" in config_data:
for token_id, new_value in token_id_mappings.items():
token_id_str = str(token_id)
if token_id_str in config_data["added_tokens_decoder"]:
config_data["added_tokens_decoder"][token_id_str][
"content"
] = new_value
else:
raise ValueError(
f"Token ID {token_id_str} not found in added_tokens_decoder"
)
# Write the updated config back
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
# 2. Update tokenizer.json - added_tokens
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
if os.path.exists(tokenizer_path):
with open(tokenizer_path, "r", encoding="utf-8") as f:
tokenizer_data = json.load(f)
# Update added_tokens
if "added_tokens" in tokenizer_data:
for token_id, new_value in token_id_mappings.items():
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
if token_entry["id"] == token_id:
tokenizer_data["added_tokens"][i]["content"] = new_value
break
else:
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
raise ValueError(
f"Token ID {token_id} not found in added_tokens"
)
# Write the updated tokenizer data back
with open(tokenizer_path, "w", encoding="utf-8") as f:
json.dump(tokenizer_data, f, indent=2)
barrier()
return tokenizer_dir
def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default
@@ -180,8 +273,18 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
# Set base tokenizer path
tokenizer_path = cfg.tokenizer_config
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
)
tokenizer = tokenizer_cls.from_pretrained(
cfg.tokenizer_config,
tokenizer_path,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
@@ -389,8 +492,8 @@ class ModelLoader:
patch_fa_peft_integration()
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.flash_attention:
self.patch_attention()
@@ -1256,7 +1359,7 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set()
for name, module in model.named_modules():
if (

View File

@@ -14,7 +14,7 @@
h1 {
font-family: var(--font-title);
font-weight: 400;
font-size: 6rem;
font-size: 5rem;
line-height: 1.1;
letter-spacing: -0.05em;
font-feature-settings: "ss01" on;

View File

@@ -25,8 +25,8 @@ def fixture_cfg():
"optimizer": "adamw_torch_fused",
"sequence_len": 2048,
"rl": True,
"adam_beta1": 0.998,
"adam_beta2": 0.9,
"adam_beta1": 0.91,
"adam_beta2": 0.998,
"adam_epsilon": 0.00001,
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
@@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder:
def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9
assert training_arguments.adam_beta1 == 0.91
assert training_arguments.adam_beta2 == 0.998
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True

View File

@@ -69,6 +69,51 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# pylint: disable=redefined-outer-name
def test_qwen2_w_cce(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"plugins": [
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
],
"cut_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
}
)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize(
"attention_type",
[

View File

@@ -750,3 +750,66 @@ class TestMultiGPULlama:
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
)

View File

@@ -66,6 +66,54 @@ class TestLlama:
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens_already_trained(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{

View File

@@ -1,6 +1,7 @@
"""
Test cases for the tokenizer loading
"""
import unittest
import pytest
@@ -9,7 +10,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
class TestTokenizers(unittest.TestCase):
class TestTokenizers:
"""
test class for the load_tokenizer fn
"""
@@ -75,12 +76,48 @@ class TestTokenizers(unittest.TestCase):
}
)
tokenizer = load_tokenizer(cfg)
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
self.assertEqual(len(tokenizer), 32001)
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
assert len(tokenizer) == 32001
# ensure reloading the tokenizer again from cfg results in same vocab length
tokenizer = load_tokenizer(cfg)
self.assertEqual(len(tokenizer), 32001)
assert len(tokenizer) == 32001
def test_added_tokens_overrides(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {
128041: "RANDOM_OVERRIDE_1",
128042: "RANDOM_OVERRIDE_2",
},
"output_dir": temp_dir,
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
128041
]
assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
128042
]
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
"output_dir": temp_dir,
}
)
with pytest.raises(
ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
):
load_tokenizer(cfg)
if __name__ == "__main__":