Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
76bb09784d fix import 2025-03-05 14:05:27 -05:00
Wing Lian
0542c7dd56 add muon optimizer
optimizer_cls_and_kwargs is on trainer_kwargs
only add adamw_kwargs if they're non-null
fix mocks
better handling of override and check the optimizer
unwrap optimizer
2025-03-05 10:47:22 -05:00
36 changed files with 81 additions and 590 deletions

View File

@@ -88,11 +88,6 @@ jobs:
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -80,11 +80,6 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \ RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \ RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \ mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \ chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \ printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

View File

@@ -154,6 +154,8 @@ datasets:
content: value content: value
# ... # ...
message_property_mappings:
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is: # Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles: roles:
user: ["human", "user"] user: ["human", "user"]
@@ -554,13 +556,6 @@ special_tokens:
# Add extra tokens. # Add extra tokens.
tokens: tokens:
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
# Can be checked if they exist in tokenizer.json added_tokens.
added_tokens_overrides: # Dict[int, str]
# 128041: "<|im_start|>"
# 128042: "<|im_end|>"
# FSDP # FSDP
fsdp: fsdp:
fsdp_config: fsdp_config:

View File

@@ -74,10 +74,6 @@ datasets:
train_on_eos: train_on_eos:
``` ```
::: {.callout-tip}
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages. 2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml ```yaml

View File

@@ -52,7 +52,3 @@ description: Frequently asked questions
**Q: The EOS/EOT token is incorrectly being masked or not being masked.** **Q: The EOS/EOT token is incorrectly being masked or not being masked.**
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template. > A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.

View File

@@ -28,17 +28,6 @@ val_set_size: 0.1
eval_steps: 100 eval_steps: 100
``` ```
Bradley-Terry chat templates expect single-turn conversations in the following format:
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### Process Reward Models (PRM) ### Process Reward Models (PRM)
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning. Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
@@ -56,5 +45,3 @@ datasets:
val_set_size: 0.1 val_set_size: 0.1
eval_steps: 100 eval_steps: 100
``` ```
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.

View File

@@ -3,7 +3,6 @@ title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback." description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true back-to-top-navigation: true
toc: true toc: true
toc-expand: 2
toc-depth: 4 toc-depth: 4
--- ---
@@ -529,7 +528,6 @@ trl:
vllm_gpu_memory_utilization: 0.15 vllm_gpu_memory_utilization: 0.15
num_generations: 4 num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}' reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]
datasets: datasets:
- path: openai/gsm8k - path: openai/gsm8k
name: main name: main
@@ -538,8 +536,6 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function). To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### Using local dataset files ### Using local dataset files
```yaml ```yaml

View File

@@ -62,5 +62,5 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0 torchao==0.7.0
schedulefree==1.3.0 schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.6 axolotl-contribs-lgpl==0.0.3
axolotl-contribs-mit==0.0.3 axolotl-contribs-mit==0.0.3

View File

@@ -24,5 +24,5 @@ if cce_spec:
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"' + 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
) )

View File

@@ -113,7 +113,7 @@ class ModalCloud(Cloud):
[ [
# Random id for cache busting of branch commits # Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311 f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull", f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
] ]
) )
@@ -270,7 +270,6 @@ def _preprocess(config_yaml: str, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs): def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out: with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml) f_out.write(config_yaml)
run_folder = "/workspace/mounts" run_folder = "/workspace/mounts"
@@ -289,7 +288,6 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _lm_eval(config_yaml: str, volumes=None): def _lm_eval(config_yaml: str, volumes=None):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out: with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml) f_out.write(config_yaml)
run_folder = "/workspace/mounts" run_folder = "/workspace/mounts"

View File

@@ -1,7 +1,6 @@
"""CLI to run training on a model.""" """CLI to run training on a model."""
import logging import logging
import os
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -35,8 +34,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
""" """
print_axolotl_text_art() print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0: check_user_token()
check_user_token()
if cfg.rl: if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
process_or_cpu_count: int = ( process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment] process_count or os.cpu_count() # type: ignore[assignment]
) )
num_proc = min(32, process_or_cpu_count) num_proc = min(64, process_or_cpu_count)
features = data.features.keys() features = data.features.keys()
tokenized_data = data.map( tokenized_data = data.map(
map_fn, map_fn,

View File

@@ -751,12 +751,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_ce_alpha is not None: if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_ce_alpha_end is not None:
training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end
if self.cfg.kd_alpha is not None: if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_alpha_end is not None:
training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end
if self.cfg.kd_temperature is not None: if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None: if self.cfg.kd_zscore_base_temp is not None:

View File

@@ -17,7 +17,7 @@ Run the following command to install `cut_cross_entropy[transformers]` if you do
python scripts/cutcrossentropy_install.py | sh python scripts/cutcrossentropy_install.py | sh
# if you are not in dev environment # if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c" pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"
``` ```
## Usage ## Usage

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = ( _CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using " "Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`' '`pip install "cut-cross-entropy[transformers]==24.11.4"`'
) )

View File

@@ -34,12 +34,3 @@ class KDPlugin(BasePlugin):
return AxolotlKDTrainer return AxolotlKDTrainer
return None return None
def add_callbacks_post_trainer(self, cfg, trainer):
    """Return the extra trainer callbacks contributed by this plugin.

    Args:
        cfg: Config object; only ``cfg.kd_trainer`` is read.
        trainer: Accepted for the hook signature; not used here.

    Returns:
        A new list holding one ``KDAlphaSchedulerCallback`` when
        ``cfg.kd_trainer`` is truthy, otherwise a new empty list.
    """
    if not cfg.kd_trainer:
        return []

    # Imported only on the KD path, so the callbacks module is not loaded
    # unless KD training is enabled.
    from .callbacks import KDAlphaSchedulerCallback

    return [KDAlphaSchedulerCallback()]

View File

@@ -30,8 +30,6 @@ class KDArgs(BaseModel):
float float
] = None # loss coefficient for cross-entropy loss during KD ] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_ce_alpha_end: Optional[float] = None # end value for kd_ce_alpha
kd_alpha_end: Optional[float] = None # end value for kd_alpha
kd_temperature: Optional[float] = None # temperature for sampling during KD kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[ kd_top_k_before_softmax: Optional[

View File

@@ -1,28 +0,0 @@
from transformers import TrainerCallback
class KDAlphaSchedulerCallback(TrainerCallback):
    """Callback for scheduling the KD loss coefficients per training epoch."""

    def on_epoch_begin(
        self, args, state, control, **kwargs  # pylint: disable=unused-argument
    ):
        """Update ``kd_alpha`` / ``kd_ce_alpha`` at the start of an epoch.

        * Epoch 0: copies ``args.kd_alpha`` and ``args.kd_ce_alpha`` onto
          ``state``.
        * Last epoch (``state.num_train_epochs - 1``): for each coefficient
          whose ``*_end`` value is not None, stores that end value on
          ``control``.
        * Any other epoch: for each coefficient whose ``*_end`` value is not
          None, stores ``start + (end - start) * epoch / (num_train_epochs - 1)``
          on ``control``.

        Coefficients without an ``*_end`` value are left untouched after
        epoch 0.
        """
        current_epoch = int(state.epoch)

        if current_epoch == 0:
            # NOTE(review): the starting values are written to `state`, while
            # every later epoch writes to `control` — confirm which object the
            # trainer actually reads the coefficients from.
            state.kd_alpha = args.kd_alpha
            state.kd_ce_alpha = args.kd_ce_alpha
            return

        last_epoch = state.num_train_epochs - 1
        if current_epoch == last_epoch:
            # Final epoch: jump straight to the configured end values.
            progress = None
        else:
            # Fraction of the schedule completed so far.
            progress = current_epoch / last_epoch

        for coeff in ("kd_alpha", "kd_ce_alpha"):
            end_value = getattr(args, f"{coeff}_end")
            if end_value is None:
                continue
            if progress is None:
                setattr(control, coeff, end_value)
            else:
                start_value = getattr(args, coeff)
                setattr(
                    control,
                    coeff,
                    start_value + (end_value - start_value) * progress,
                )

View File

@@ -62,16 +62,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
Transform logprobs to target format for KD training Transform logprobs to target format for KD training
""" """
if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys(): logprobs = sample.pop(self.logprobs_field)
logprobs = sample.pop("target_logprobs")
token_ids = sample.pop("target_token_ids")
else:
logprobs = sample.pop(self.logprobs_field)
token_ids = [None] * len(logprobs)
target_seq_len = len(logprobs) target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"]) input_seq_len = len(sample["input_ids"])
target_padding_len = input_seq_len - target_seq_len input_padding_len = input_seq_len - target_seq_len
# get non-zero top-k (prune None logprobs from vllm data step) # get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [ top_k_vals = [
len(logprobs[i]) len(logprobs[i])
@@ -88,11 +82,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_token_ids = [] target_token_ids = []
target_mask = [] target_mask = []
if target_padding_len < 0: if input_padding_len < 0:
# logprobs is longer than target_seq_len, # logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs # so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len] logprobs = logprobs[:-input_seq_len]
target_padding_len = 0 input_padding_len = 0
# target_seq_len = input_seq_len # target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k # truncate the second dimension of the logprobs to top_k
@@ -104,37 +98,33 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# for causal models, if we start the range at 1, then we don't need to shift in the trainer # for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer # otherwise, we need to shift in the trainer
shift = 0 shift = 0
for _ in range(shift, target_padding_len): for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k) target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k))) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k) target_mask.append([0] * top_k)
for position in range(target_padding_len, input_seq_len): for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100: if sample["labels"][position] == -100:
target_mask.append([0] * top_k) target_mask.append([0] * top_k)
else: else:
target_mask.append([1] * top_k) target_mask.append([1] * top_k)
for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids): for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids # Initialize collections for logprobs and token_ids
position_logprobs = [] position_logprobs = []
position_token_ids = [] position_token_ids = []
# Process each token probability entry # Process each token probability entry
if token_pos_token_ids is None: for entry in token_pos_logprobs:
for entry in token_pos_logprobs: # Extract logprob value
# Extract logprob value logprob = entry["logprob"]
logprob = entry["logprob"]
# Parse token_id from the "token_id:###" format # Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1]) token_id = int(entry["token"].split(":")[1])
# Append to our collections # Append to our collections
position_logprobs.append(logprob) position_logprobs.append(logprob)
position_token_ids.append(token_id) position_token_ids.append(token_id)
else:
position_logprobs = token_pos_logprobs
position_token_ids = token_pos_token_ids
# Convert to a tensor for easier manipulation # Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor( position_logprobs_tensor = torch.tensor(
@@ -153,7 +143,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
teacher_probs_t2 = teacher_probs_t1**exponent teacher_probs_t2 = teacher_probs_t1**exponent
else: else:
teacher_probs_t2 = teacher_probs_t1 teacher_probs_t2 = teacher_probs_t1
# Re-normalize # Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True dim=0, keepdim=True

View File

@@ -16,35 +16,17 @@
KD trainer KD trainer
""" """
from transformers import TrainerControl
from axolotl.core.trainers.base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainerControl(TrainerControl):
    """``TrainerControl`` variant that also carries the KD loss coefficients."""

    # Coefficient applied to the KD loss term.
    kd_alpha: float = 1.0
    # Coefficient applied to the cross-entropy loss term.
    kd_ce_alpha: float = 0.0

    def state(self) -> dict:
        """Return the parent's exported state with the KD coefficients added.

        ``kd_alpha`` and ``kd_ce_alpha`` are written into the ``"args"`` entry
        of the dict produced by ``super().state()``.
        """
        # NOTE(review): assumes the parent's state() returns a dict with an
        # "args" mapping — confirm against the installed transformers version.
        state_val = super().state()
        state_val["args"]["kd_alpha"] = self.kd_alpha
        state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha
        # Hand the augmented dict back; without this return the override
        # yields None despite the `-> dict` annotation.
        return state_val
class AxolotlKDTrainer(AxolotlTrainer): class AxolotlKDTrainer(AxolotlTrainer):
""" """
Custom trainer subclass for Knowledge Distillation (KD) Custom trainer subclass for Knowledge Distillation (KD)
""" """
def __init__(self, *args, **kwargs):
    """Initialise the base trainer, then set up the KD-specific attributes.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor. Afterwards the KD loss coefficients are copied from
    ``self.args`` onto the trainer and ``self.control`` is replaced with a
    fresh ``AxolotlKDTrainerControl``.
    """
    super().__init__(*args, **kwargs)

    # `self.args` is read only after the parent constructor has run
    # (presumably that is where it gets populated — verify against the base class).
    training_args = self.args
    self.kd_alpha = training_args.kd_alpha
    self.kd_ce_alpha = training_args.kd_ce_alpha

    self.control = AxolotlKDTrainerControl()
def _set_signature_columns_if_needed(self): def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed() super()._set_signature_columns_if_needed()
columns_to_add = [] columns_to_add = []
@@ -113,8 +95,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0, top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
) )
if self.kd_ce_alpha > 0: if self.args.kd_ce_alpha > 0:
loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else: else:
loss = loss_kd loss = loss_kd
# Save past state if it exists # Save past state if it exists

View File

@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
""" """
from typing import Optional from typing import Optional
from pydantic import BaseModel, model_validator from pydantic import BaseModel
class SpectrumArgs(BaseModel): class SpectrumArgs(BaseModel):
@@ -27,20 +27,3 @@ class SpectrumArgs(BaseModel):
spectrum_top_fraction: Optional[float] = 0.5 spectrum_top_fraction: Optional[float] = 0.5
spectrum_model_name: Optional[str] = None spectrum_model_name: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_fsdp_use_orig_params(cls, data):
    """Reject FSDP + SpectrumPlugin when ``use_orig_params`` is not enabled.

    Args:
        data: Raw input mapping; the keys ``fsdp``, ``fsdp_config`` and
            ``plugins`` are inspected.

    Returns:
        ``data`` unchanged when the combination is allowed.

    Raises:
        ValueError: If ``fsdp`` and ``fsdp_config`` are set, the config's
            ``use_orig_params`` is falsy or missing, and any entry of
            ``plugins`` contains ``"SpectrumPlugin"``.
    """
    if not data.get("fsdp"):
        return data

    fsdp_config = data.get("fsdp_config")
    if not fsdp_config or fsdp_config.get("use_orig_params"):
        return data

    plugins = data.get("plugins")
    if not plugins:
        return data

    for plugin in plugins:
        if "SpectrumPlugin" in plugin:
            # Fail up front; the run would otherwise raise
            # ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
            raise ValueError(
                "FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
            )
    return data

View File

@@ -7,7 +7,7 @@ import signal
import sys import sys
import weakref import weakref
from pathlib import Path from pathlib import Path
from typing import Any, Dict from typing import Any
import torch import torch
import transformers.modelcard import transformers.modelcard
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
) )
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -382,23 +382,21 @@ def handle_untrained_tokens_fix(
if not cfg.fix_untrained_tokens: if not cfg.fix_untrained_tokens:
return return
is_ds_zero3: bool = False
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
is_ds_zero3 = True
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args # Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens) sig = inspect.signature(fix_untrained_tokens)
fix_kwargs: Dict[str, Any] = {}
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list # If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance( if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list cfg.fix_untrained_tokens, list
): ):
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens fix_untrained_tokens(
if "is_ds_zero3" in sig.parameters: model,
fix_kwargs["is_ds_zero3"] = is_ds_zero3 tokenizer,
train_dataset,
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs) token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0: if cfg.local_rank == 0:
model.save_pretrained( model.save_pretrained(

View File

@@ -813,15 +813,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
) )
except (FileNotFoundError, ConnectionError) as err: except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}") LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
# TODO if using deepspeed and it's a file, save deepspeed config too
if args.deepspeed and os.path.isfile(args.deepspeed):
LOG.info(f"DeepSpeed config has been saved to the WandB run.")
artifact = wandb.Artifact(
f"deepspeed-{wandb.run.id}", type="deepspeed-config"
)
artifact.add_file(args.deepspeed)
wandb.log_artifact(artifact)
wandb.save(args.deepspeed)
return control return control

View File

@@ -173,16 +173,10 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
] ]
out_features[i][feature] = np.concatenate(arrays) out_features[i][feature] = np.concatenate(arrays)
else: else:
try: arrays = [
arrays = [ np.array(item[feature]) for item in features_ if feature in item
np.array(item[feature]) ]
for item in features_ out_features[i][feature] = np.concatenate(arrays)
if feature in item
]
if arrays[0].dtype != "object":
out_features[i][feature] = np.concatenate(arrays)
except ValueError:
pass
return super().__call__(out_features, return_tensors=return_tensors) return super().__call__(out_features, return_tensors=return_tensors)

View File

@@ -72,6 +72,7 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
lion_pytorch = "lion_pytorch" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name muon = "muon" # pylint: disable=invalid-name
@@ -728,7 +729,7 @@ class AxolotlInputConfig(
default=None, default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"}, json_schema_extra={"description": "streaming dataset to use for pretraining"},
) )
dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var] dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_exact_deduplication: Optional[bool] = None dataset_exact_deduplication: Optional[bool] = None
dataset_keep_in_memory: Optional[bool] = None dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None dataloader_pin_memory: Optional[bool] = None
@@ -779,9 +780,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype] # torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[ gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
Union[Literal["unsloth", "offload"], bool] default=False
] = Field(default=False) )
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None unfrozen_parameters: Optional[List[str]] = None
@@ -856,7 +857,6 @@ class AxolotlInputConfig(
special_tokens: Optional[SpecialTokensConfig] = None special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None tokens: Optional[List[str]] = None
added_tokens_overrides: Optional[Dict[int, str]] = None
torch_compile: Optional[Union[Literal["auto"], bool]] = None torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None torch_compile_backend: Optional[str] = None
@@ -1155,15 +1155,6 @@ class AxolotlInputConfig(
raise ValueError("gradient_checkpointing is not supported for MPT models") raise ValueError("gradient_checkpointing is not supported for MPT models")
return self return self
@model_validator(mode="after")
def check_offload_grad_checkpointing(self):
    """Map the deprecated ``"unsloth"`` gradient-checkpointing value to ``"offload"``.

    When ``self.gradient_checkpointing`` equals ``"unsloth"``, a deprecation
    warning is logged and the field is rewritten to ``"offload"``. Any other
    value is left as-is.

    Returns:
        ``self``, as required of an ``after`` validator.
    """
    if self.gradient_checkpointing != "unsloth":
        return self

    LOG.warning(
        "`unsloth` is deprecated for gradient_checkpointing, use `offload`"
    )
    self.gradient_checkpointing = "offload"
    return self
@model_validator(mode="after") @model_validator(mode="after")
def check_better_transformers(self): def check_better_transformers(self):
if self.flash_optimum is True: if self.flash_optimum is True:

View File

@@ -1,8 +1,7 @@
""" """
GRPO specific configuration args GRPO specific configuration args
""" """
from typing import List, Optional
from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -12,10 +11,7 @@ class TRLConfig(BaseModel):
Input args for TRL. Input args for TRL.
""" """
beta: Optional[float] = Field( beta: Optional[float] = None
default=None,
json_schema_extra={"description": "Beta for RL training"},
)
max_completion_length: Optional[int] = Field( max_completion_length: Optional[int] = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
@@ -24,68 +20,17 @@ class TRLConfig(BaseModel):
) )
# GRPO specific args # GRPO specific args
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22 use_vllm: Optional[bool] = False
use_vllm: Optional[bool] = Field( vllm_device: Optional[str] = "auto"
default=False, vllm_gpu_memory_utilization: Optional[float] = 0.9
json_schema_extra={"description": "Whether to use VLLM for RL training"}, vllm_max_model_len: Optional[int] = None
) vllm_dtype: Optional[str] = "auto"
vllm_device: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
vllm_gpu_memory_utilization: Optional[float] = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
vllm_dtype: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
reward_funcs: Optional[list[str]] = Field( reward_funcs: Optional[List[str]] = None
default=None, reward_weights: Optional[List[float]] = None
json_schema_extra={"description": "List of reward functions to load"}, num_generations: Optional[int] = None
) log_completions: Optional[bool] = False
reward_weights: Optional[list[float]] = Field(
default=None, sync_ref_model: Optional[bool] = False
json_schema_extra={ ref_model_mixup_alpha: Optional[float] = 0.9
"description": "Weights for each reward function. Must match the number of reward functions." ref_model_sync_steps: Optional[int] = 64
},
)
num_generations: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
},
)
log_completions: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
sync_ref_model: Optional[bool] = Field(
default=False,
json_schema_extra={
"description": (
"Whether to sync the reference model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
)
},
)
ref_model_mixup_alpha: Optional[float] = Field(
default=0.9,
json_schema_extra={
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
},
)
ref_model_sync_steps: Optional[int] = Field(
default=64,
json_schema_extra={
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
)

View File

@@ -79,7 +79,7 @@ def is_main_process():
def is_local_main_process(): def is_local_main_process():
return PartialState().is_local_main_process return PartialState().is_main_process
def get_world_size(): def get_world_size():

View File

@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
) )
def hf_grad_checkpoint_offload_wrapper( def hf_grad_checkpoint_unsloth_wrapper(
decoder_layer, *args, use_reentrant=None decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply( return Unsloth_Offloaded_Gradient_Checkpointer.apply(

View File

@@ -24,6 +24,7 @@ from peft import (
PeftModelForCausalLM, PeftModelForCausalLM,
prepare_model_for_kbit_training, prepare_model_for_kbit_training,
) )
from peft.tuners.lora import QuantLinear
from torch import nn from torch import nn
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AddedToken, AddedToken,
@@ -56,14 +57,8 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import ( from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
barrier, from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
get_device_count,
get_device_type,
is_local_main_process,
zero_only,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -170,95 +165,7 @@ def load_model_config(cfg):
return model_config return model_config
def _apply_added_tokens_decoder_overrides(config_data, token_id_mappings):
    """Rename entries of ``added_tokens_decoder`` in parsed tokenizer_config.json data.

    Mutates ``config_data`` in place and does nothing when the
    ``added_tokens_decoder`` key is absent.

    Args:
        config_data: Parsed JSON content of ``tokenizer_config.json``.
        token_id_mappings: Dict mapping {token_id (int): new_token_string}

    Raises:
        ValueError: If a token id has no entry in ``added_tokens_decoder``.
    """
    if "added_tokens_decoder" not in config_data:
        return

    decoder = config_data["added_tokens_decoder"]
    for token_id, new_value in token_id_mappings.items():
        # the decoder is a JSON object, so its keys are the ids in string form
        token_id_str = str(token_id)
        if token_id_str not in decoder:
            raise ValueError(
                f"Token ID {token_id_str} not found in added_tokens_decoder"
            )
        decoder[token_id_str]["content"] = new_value


def _apply_added_tokens_overrides(tokenizer_data, token_id_mappings):
    """Rename entries of the ``added_tokens`` list in parsed tokenizer.json data.

    Mutates ``tokenizer_data`` in place and does nothing when the
    ``added_tokens`` key is absent.

    Args:
        tokenizer_data: Parsed JSON content of ``tokenizer.json``.
        token_id_mappings: Dict mapping {token_id (int): new_token_string}

    Raises:
        ValueError: If a token id has no entry in ``added_tokens``.
    """
    if "added_tokens" not in tokenizer_data:
        return

    # Index the list once instead of rescanning it for every override;
    # setdefault keeps the first entry per id, which is the one a linear
    # scan with an early break would have modified.
    first_entry_by_id = {}
    for token_entry in tokenizer_data["added_tokens"]:
        first_entry_by_id.setdefault(token_entry["id"], token_entry)

    for token_id, new_value in token_id_mappings.items():
        if token_id not in first_entry_by_id:
            raise ValueError(f"Token ID {token_id} not found in added_tokens")
        first_entry_by_id[token_id]["content"] = new_value


def _rewrite_json_file(path, apply_overrides, token_id_mappings):
    """Load the JSON file at ``path``, apply the overrides and write it back.

    A missing file is skipped silently. The file is only rewritten after
    ``apply_overrides(data, token_id_mappings)`` returns, so a ValueError
    raised by it leaves the file on disk untouched.
    """
    import json

    if not os.path.exists(path):
        return

    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    apply_overrides(data, token_id_mappings)

    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)


def modify_tokenizer_files(
    tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
) -> str:
    """
    Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.

    This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.

    The files are only written by the process for which
    `is_local_main_process()` is true; every process then calls `barrier()`
    before returning.

    Args:
        tokenizer_path: Path or name of the original tokenizer
        token_mappings: Dict mapping {token_id (int): new_token_string}
        output_dir: Directory to save the modified tokenizer

    Returns:
        Path to the modified tokenizer directory

    Raises:
        ValueError: If a token id is missing from `added_tokens_decoder` in
            tokenizer_config.json or from `added_tokens` in tokenizer.json.

    Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
    """
    # Create the tokenizer directory in output_dir if it doesn't exist
    tokenizer_dir = os.path.join(output_dir, "tokenizer")
    os.makedirs(tokenizer_dir, exist_ok=True)

    if is_local_main_process():
        # Load the original tokenizer and save a copy to edit in place
        temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
        temp_tokenizer.save_pretrained(tokenizer_dir)

        # Normalize the ids to int (config loaders may hand them over as strings)
        token_id_mappings = {
            int(token_id): new_value for token_id, new_value in token_mappings.items()
        }

        # 1. Update tokenizer_config.json - added_tokens_decoder
        _rewrite_json_file(
            os.path.join(tokenizer_dir, "tokenizer_config.json"),
            _apply_added_tokens_decoder_overrides,
            token_id_mappings,
        )

        # 2. Update tokenizer.json - added_tokens
        # (a separate path expression: the `tokenizer_path` argument is not rebound)
        _rewrite_json_file(
            os.path.join(tokenizer_dir, "tokenizer.json"),
            _apply_added_tokens_overrides,
            token_id_mappings,
        )

    # NOTE(review): if the block above raises, this process never reaches the
    # barrier; presumably the other ranks then wait until the launcher tears
    # the job down - confirm that is acceptable.
    barrier()
    return tokenizer_dir
def load_tokenizer(cfg): def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg) model_config = load_model_config(cfg)
tokenizer_kwargs = {} tokenizer_kwargs = {}
use_fast = True # this is the default use_fast = True # this is the default
@@ -273,18 +180,8 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type: if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type) tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
# Set base tokenizer path
tokenizer_path = cfg.tokenizer_config
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
)
tokenizer = tokenizer_cls.from_pretrained( tokenizer = tokenizer_cls.from_pretrained(
tokenizer_path, cfg.tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast, use_fast=use_fast,
**tokenizer_kwargs, **tokenizer_kwargs,
@@ -492,8 +389,8 @@ class ModelLoader:
patch_fa_peft_integration() patch_fa_peft_integration()
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]: if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.flash_attention: if self.cfg.flash_attention:
self.patch_attention() self.patch_attention()
@@ -1359,7 +1256,7 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model): def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear) cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set() lora_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if ( if (

View File

@@ -14,7 +14,7 @@
h1 { h1 {
font-family: var(--font-title); font-family: var(--font-title);
font-weight: 400; font-weight: 400;
font-size: 5rem; font-size: 6rem;
line-height: 1.1; line-height: 1.1;
letter-spacing: -0.05em; letter-spacing: -0.05em;
font-feature-settings: "ss01" on; font-feature-settings: "ss01" on;

View File

@@ -25,8 +25,8 @@ def fixture_cfg():
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"sequence_len": 2048, "sequence_len": 2048,
"rl": True, "rl": True,
"adam_beta1": 0.91, "adam_beta1": 0.998,
"adam_beta2": 0.998, "adam_beta2": 0.9,
"adam_epsilon": 0.00001, "adam_epsilon": 0.00001,
"dataloader_num_workers": 1, "dataloader_num_workers": 1,
"dataloader_pin_memory": True, "dataloader_pin_memory": True,
@@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder:
def test_build_training_arguments(self, cfg, model, tokenizer): def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFRLTrainerBuilder(cfg, model, tokenizer) builder = HFRLTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100) training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.91 assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.998 assert training_arguments.adam_beta2 == 0.9
assert training_arguments.adam_epsilon == 0.00001 assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.dataloader_num_workers == 1 assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True assert training_arguments.dataloader_pin_memory is True

View File

@@ -69,51 +69,6 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
# pylint: disable=redefined-outer-name
def test_qwen2_w_cce(self, temp_dir):
    """Short Qwen2.5-0.5B training run with the CutCrossEntropy plugin enabled.

    On PyTorch older than 2.4 the training call is expected to raise
    ImportError; otherwise training must run and leave model output in
    ``temp_dir``.
    """
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "Qwen/Qwen2.5-0.5B",
        "plugins": [
            "axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
        ],
        "cut_cross_entropy": True,
        "sequence_len": 1024,
        "val_set_size": 0.1,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
        "datasets": [alpaca_dataset],
        "num_epochs": 1,
        "micro_batch_size": 4,
        "gradient_accumulation_steps": 1,
        "learning_rate": 0.00001,
        "optimizer": "adamw_torch_fused",
        "output_dir": temp_dir,
        "lr_scheduler": "cosine",
        "save_safetensors": True,
        "max_steps": 10,
        "bf16": "auto",
    }
    cfg = DictDefault(settings)
    prepare_plugins(cfg)
    normalize_config(cfg)
    dataset_meta = load_datasets(cfg=cfg, cli_args=TrainerCliArgs())

    torch_major, torch_minor, _ = get_pytorch_version()
    if (torch_major, torch_minor) >= (2, 4):
        # supported torch version: training must succeed and produce output
        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        return

    # torch older than 2.4: the run is expected to fail with ImportError
    with pytest.raises(ImportError):
        train(cfg=cfg, dataset_meta=dataset_meta)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attention_type", "attention_type",
[ [

View File

@@ -750,66 +750,3 @@ class TestMultiGPULlama:
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
) )
def test_fix_untrained_tokens(self, temp_dir):
    """Two-process DeepSpeed ZeRO-3 run with ``fix_untrained_tokens`` and custom bos/eos tokens.

    Writes the config to ``config.yaml`` in ``temp_dir``, launches
    ``axolotl train`` as a subprocess, then checks the logged train loss
    against a threshold of 4.0.
    """
    # pylint: disable=duplicate-code
    chat_dataset = {
        "chat_template": "jinja",
        "chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
        "path": "mlabonne/FineTome-100k",
        "type": "chat_template",
        "split": "train[:10%]",
        "field_messages": "conversations",
        "message_field_role": "from",
        "message_field_content": "value",
    }
    settings = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "fix_untrained_tokens": True,
        "sequence_len": 512,
        "val_set_size": 0.0,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
            "bos_token": "<|custom_im_start|>",
            "eos_token": "<|custom_im_end|>",
        },
        "datasets": [chat_dataset],
        "num_epochs": 1,
        "max_steps": 5,
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 0.00001,
        "optimizer": "adamw_torch_fused",
        "lr_scheduler": "cosine",
        "flash_attention": True,
        "sample_packing": True,
        "bf16": True,
        "save_safetensors": True,
        "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
        "use_tensorboard": True,
    }
    cfg = DictDefault(settings)

    # write cfg to yaml file
    run_dir = Path(temp_dir)
    run_dir.mkdir(parents=True, exist_ok=True)
    config_file = run_dir / "config.yaml"
    config_file.write_text(
        yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper), encoding="utf-8"
    )

    launch_command = [
        "axolotl",
        "train",
        str(config_file),
        "--num-processes",
        "2",
        "--main-process-port",
        str(get_torch_dist_unique_port()),
    ]
    execute_subprocess_async(launch_command)

    check_tensorboard(
        temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
    )

View File

@@ -66,54 +66,6 @@ class TestLlama:
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir): def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens_already_trained(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {

View File

@@ -1,7 +1,6 @@
""" """
Test cases for the tokenizer loading Test cases for the tokenizer loading
""" """
import unittest import unittest
import pytest import pytest
@@ -10,7 +9,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer from axolotl.utils.models import load_tokenizer
class TestTokenizers: class TestTokenizers(unittest.TestCase):
""" """
test class for the load_tokenizer fn test class for the load_tokenizer fn
""" """
@@ -76,48 +75,12 @@ class TestTokenizers:
} }
) )
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404] self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
assert len(tokenizer) == 32001 self.assertEqual(len(tokenizer), 32001)
# ensure reloading the tokenizer again from cfg results in same vocab length # ensure reloading the tokenizer again from cfg results in same vocab length
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert len(tokenizer) == 32001 self.assertEqual(len(tokenizer), 32001)
def test_added_tokens_overrides(self, temp_dir):
    """Each overridden reserved token string must encode to exactly its own token id."""
    overrides = {
        128041: "RANDOM_OVERRIDE_1",
        128042: "RANDOM_OVERRIDE_2",
    }
    settings = {
        # use with tokenizer that has reserved_tokens in added_tokens
        "tokenizer_config": "NousResearch/Llama-3.2-1B",
        "added_tokens_overrides": overrides,
        "output_dir": temp_dir,
    }

    tokenizer = load_tokenizer(DictDefault(settings))

    for token_id, token_string in overrides.items():
        assert tokenizer.encode(token_string, add_special_tokens=False) == [
            token_id
        ]
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
    """Overriding a token id the tokenizer does not have must raise ValueError."""
    settings = {
        # use with tokenizer that has reserved_tokens in added_tokens
        "tokenizer_config": "NousResearch/Llama-3.2-1B",
        "added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
        "output_dir": temp_dir,
    }
    cfg = DictDefault(settings)

    expected_message = r".*Token ID 1000000 not found in added_tokens.*"
    with pytest.raises(ValueError, match=expected_message):
        load_tokenizer(cfg)
if __name__ == "__main__": if __name__ == "__main__":