Merge branch 'main' into telemetry-opt-in
This commit is contained in:
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.10.0.dev0"
|
||||
__version__ = "0.13.0.dev"
|
||||
|
||||
@@ -4,5 +4,7 @@ import os
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
|
||||
configure_logging()
|
||||
|
||||
@@ -14,9 +14,13 @@ class PreprocessCliArgs:
|
||||
prompter: Optional[str] = field(default=None)
|
||||
download: Optional[bool] = field(default=True)
|
||||
iterable: Optional[bool] = field(
|
||||
default=None,
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use IterableDataset for streaming processing of large datasets"
|
||||
"help": (
|
||||
"Deprecated in v0.13.0, will be removed in v0.14.0. For streaming "
|
||||
"datasets, use 'axolotl train' and set 'streaming: true' in your YAML "
|
||||
"config, or pass --streaming instead in the CLI."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -30,8 +34,6 @@ class TrainerCliArgs:
|
||||
debug_num_examples: int = field(default=0)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
main_process_port: Optional[int] = field(default=None)
|
||||
num_processes: Optional[int] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -42,6 +44,12 @@ class VllmServeCliArgs:
|
||||
default=None,
|
||||
metadata={"help": "Number of tensor parallel workers to use."},
|
||||
)
|
||||
data_parallel_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
|
||||
},
|
||||
)
|
||||
host: Optional[str] = field(
|
||||
default=None, # nosec B104
|
||||
metadata={"help": "Host address to run the server on."},
|
||||
@@ -107,6 +115,7 @@ class QuantizeCliArgs:
|
||||
quantize_embedding: Optional[bool] = field(default=None)
|
||||
group_size: Optional[int] = field(default=None)
|
||||
output_dir: Optional[str] = field(default=None)
|
||||
hub_model_id: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -22,7 +22,7 @@ HAS_PRINTED_LOGO = False
|
||||
def print_axolotl_text_art():
|
||||
"""Prints axolotl ASCII art."""
|
||||
|
||||
global HAS_PRINTED_LOGO # pylint: disable=global-statement
|
||||
global HAS_PRINTED_LOGO
|
||||
if HAS_PRINTED_LOGO:
|
||||
return
|
||||
if is_main_process():
|
||||
|
||||
@@ -6,6 +6,7 @@ from pathlib import Path
|
||||
from accelerate.commands.config import config_args
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from requests import HTTPError
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -46,3 +47,8 @@ def check_user_token() -> bool:
|
||||
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||
)
|
||||
return False
|
||||
except HTTPError:
|
||||
LOG.warning(
|
||||
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -3,16 +3,17 @@ launch axolotl in supported cloud platforms
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.cloud.base import Cloud
|
||||
from axolotl.cli.cloud.baseten import BasetenCloud
|
||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
||||
def load_cloud_cfg(cloud_config: Path | str) -> DictDefault:
|
||||
"""Load and validate cloud configuration."""
|
||||
# Load cloud configuration.
|
||||
with open(cloud_config, encoding="utf-8") as file:
|
||||
@@ -21,10 +22,9 @@ def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
||||
|
||||
|
||||
def do_cli_preprocess(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
@@ -33,28 +33,40 @@ def do_cli_preprocess(
|
||||
|
||||
|
||||
def do_cli_train(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
accelerate: bool = True,
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
cwd=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
|
||||
provider = cloud_cfg.provider or "modal"
|
||||
cloud: Cloud | None
|
||||
if provider == "modal":
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
elif provider == "baseten":
|
||||
cloud = BasetenCloud(cloud_cfg.to_dict())
|
||||
else:
|
||||
raise ValueError(f"Unsupported cloud provider: {provider}")
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
local_dirs = {}
|
||||
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
|
||||
local_dirs = {"/workspace/mounts": cwd}
|
||||
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
|
||||
cloud.train(
|
||||
config_yaml,
|
||||
launcher=launcher,
|
||||
launcher_args=launcher_args,
|
||||
local_dirs=local_dirs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def do_cli_lm_eval(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
|
||||
@@ -3,6 +3,7 @@ base class for cloud platforms from cli
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class Cloud(ABC):
|
||||
@@ -15,5 +16,12 @@ class Cloud(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train(self, config_yaml: str, accelerate: bool = True) -> str:
|
||||
def train(
|
||||
self,
|
||||
config_yaml: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
local_dirs: dict[str, str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
pass
|
||||
|
||||
48
src/axolotl/cli/cloud/baseten/__init__.py
Normal file
48
src/axolotl/cli/cloud/baseten/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Baseten Cloud CLI"""
|
||||
|
||||
import shutil
|
||||
import subprocess # nosec B404
|
||||
import tempfile
|
||||
from os.path import dirname
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.cli.cloud.base import Cloud
|
||||
|
||||
|
||||
class BasetenCloud(Cloud):
|
||||
"""Baseten Cloud Axolotl CLI"""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
|
||||
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
|
||||
raise NotImplementedError(
|
||||
"Separate preprocess function for Baseten is not "
|
||||
"implemented and will happen during hte train step."
|
||||
)
|
||||
|
||||
def train(
|
||||
self,
|
||||
config_yaml: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
|
||||
**kwargs,
|
||||
):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config = self.config.copy()
|
||||
config["launcher"] = launcher
|
||||
config["launcher_args"] = launcher_args
|
||||
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
|
||||
yaml.dump(config, cloud_fout)
|
||||
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
|
||||
config_fout.write(config_yaml)
|
||||
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
|
||||
shutil.copyfile(
|
||||
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
|
||||
)
|
||||
subprocess.run( # nosec B603 B607
|
||||
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
|
||||
)
|
||||
9
src/axolotl/cli/cloud/baseten/template/run.sh
Normal file
9
src/axolotl/cli/cloud/baseten/template/run.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
set -eux
|
||||
|
||||
export NCCL_SOCKET_IFNAME="^docker0,lo"
|
||||
export NCCL_IB_DISABLE=0
|
||||
export NCCL_TIMEOUT=1800000
|
||||
|
||||
axolotl preprocess train.yaml
|
||||
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}
|
||||
71
src/axolotl/cli/cloud/baseten/template/train_sft.py
Normal file
71
src/axolotl/cli/cloud/baseten/template/train_sft.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Baseten Training Script for Axolotl
|
||||
"""
|
||||
|
||||
# pylint: skip-file
|
||||
import yaml
|
||||
from truss.base import truss_config
|
||||
|
||||
# Import necessary classes from the Baseten Training SDK
|
||||
from truss_train import definitions
|
||||
|
||||
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
|
||||
gpu = cloud_config.get("gpu", "h100")
|
||||
gpu_count = int(cloud_config.get("gpu_count", 1))
|
||||
node_count = int(cloud_config.get("node_count", 1))
|
||||
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
|
||||
secrets = cloud_config.get("secrets", [])
|
||||
launcher = cloud_config.get("launcher", "accelerate")
|
||||
launcher_args = cloud_config.get("launcher_args", [])
|
||||
script_name = "run.sh"
|
||||
|
||||
launcher_args_str = ""
|
||||
if launcher_args:
|
||||
launcher_args_str = "-- " + " ".join(launcher_args)
|
||||
|
||||
# 1. Define a base image for your training job
|
||||
# must use torch 2.7.0 for vllm
|
||||
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
|
||||
|
||||
# 2. Define the Runtime Environment for the Training Job
|
||||
# This includes start commands and environment variables.a
|
||||
# Secrets from the baseten workspace like API keys are referenced using
|
||||
# `SecretReference`.
|
||||
|
||||
env_vars = {
|
||||
"AXOLOTL_LAUNCHER": launcher,
|
||||
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
|
||||
}
|
||||
for secret_name in secrets:
|
||||
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
|
||||
|
||||
training_runtime = definitions.Runtime(
|
||||
start_commands=[ # Example: list of commands to run your training script
|
||||
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
|
||||
],
|
||||
environment_variables=env_vars,
|
||||
)
|
||||
|
||||
# 3. Define the Compute Resources for the Training Job
|
||||
training_compute = definitions.Compute(
|
||||
node_count=node_count,
|
||||
accelerator=truss_config.AcceleratorSpec(
|
||||
accelerator=truss_config.Accelerator.H100,
|
||||
count=gpu_count,
|
||||
),
|
||||
)
|
||||
|
||||
# 4. Define the Training Job
|
||||
# This brings together the image, compute, and runtime configurations.
|
||||
my_training_job = definitions.TrainingJob(
|
||||
image=definitions.Image(base_image=BASE_IMAGE),
|
||||
compute=training_compute,
|
||||
runtime=training_runtime,
|
||||
)
|
||||
|
||||
|
||||
# This config will be pushed using the Truss CLI.
|
||||
# The association of the job to the project happens at the time of push.
|
||||
first_project_with_job = definitions.TrainingProject(
|
||||
name=project_name, job=my_training_job
|
||||
)
|
||||
@@ -8,7 +8,7 @@ import os
|
||||
import subprocess # nosec B404
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from typing import Optional
|
||||
from typing import Literal
|
||||
|
||||
import modal
|
||||
|
||||
@@ -41,7 +41,7 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
|
||||
if exit_code := subprocess.call( # nosec B603
|
||||
cmd.split(), cwd=run_folder, env=new_env
|
||||
):
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
exit(exit_code)
|
||||
|
||||
# Commit writes to volume.
|
||||
if volumes:
|
||||
@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
|
||||
return res
|
||||
|
||||
def get_image(self):
|
||||
docker_tag = "main-py3.11-cu124-2.6.0"
|
||||
docker_tag = "main-py3.11-cu126-2.7.1"
|
||||
if self.config.docker_tag:
|
||||
docker_tag = self.config.docker_tag
|
||||
docker_image = f"axolotlai/axolotl:{docker_tag}"
|
||||
@@ -130,7 +130,6 @@ class ModalCloud(Cloud):
|
||||
res = []
|
||||
if self.config.secrets:
|
||||
for key in self.config.get("secrets", []):
|
||||
# pylint: disable=duplicate-code
|
||||
if isinstance(key, str):
|
||||
if val := os.environ.get(key, ""):
|
||||
res.append(modal.Secret.from_dict({key: val}))
|
||||
@@ -177,8 +176,8 @@ class ModalCloud(Cloud):
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
*args,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -187,7 +186,7 @@ class ModalCloud(Cloud):
|
||||
return int(self.config.timeout)
|
||||
return 60 * 60 * 24 # 24 hours
|
||||
|
||||
def get_train_gpu(self): # pylint: disable=too-many-return-statements
|
||||
def get_train_gpu(self):
|
||||
count = self.config.gpu_count or 1
|
||||
family = self.config.gpu.lower() or "l40s"
|
||||
|
||||
@@ -200,7 +199,7 @@ class ModalCloud(Cloud):
|
||||
if family in ["a10", "a10g"]:
|
||||
return modal.gpu.A10G(count=count)
|
||||
if family == "h100":
|
||||
return modal.gpu.H100(count=count)
|
||||
return f"H100:{count}"
|
||||
if family == "t4":
|
||||
return modal.gpu.T4(count=count)
|
||||
if family == "l4":
|
||||
@@ -230,8 +229,9 @@ class ModalCloud(Cloud):
|
||||
def train(
|
||||
self,
|
||||
config_yaml: str,
|
||||
accelerate: bool = True,
|
||||
local_dirs: Optional[dict[str, str]] = None,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
local_dirs: dict[str, str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
modal_fn = self.get_train_env(local_dirs)(_train)
|
||||
@@ -239,7 +239,8 @@ class ModalCloud(Cloud):
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
accelerate=accelerate,
|
||||
launcher=launcher,
|
||||
launcher_args=launcher_args,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
**kwargs,
|
||||
)
|
||||
@@ -270,20 +271,35 @@ def _preprocess(config_yaml: str, volumes=None):
|
||||
)
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
def _train(
|
||||
config_yaml: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
volumes=None,
|
||||
**kwargs,
|
||||
):
|
||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/mounts"
|
||||
if accelerate:
|
||||
accelerate_args = "--accelerate"
|
||||
|
||||
launcher_args = launcher_args or []
|
||||
|
||||
# Build the base command
|
||||
if launcher == "accelerate":
|
||||
launcher_arg = "--launcher accelerate"
|
||||
elif launcher == "torchrun":
|
||||
launcher_arg = "--launcher torchrun"
|
||||
else:
|
||||
accelerate_args = "--no-accelerate"
|
||||
num_processes_args = ""
|
||||
if num_processes := kwargs.pop("num_processes", None):
|
||||
num_processes_args = f"--num-processes {num_processes}"
|
||||
launcher_arg = "--launcher python"
|
||||
|
||||
# Build launcher args string
|
||||
launcher_args_str = ""
|
||||
if launcher_args:
|
||||
launcher_args_str = "-- " + " ".join(launcher_args)
|
||||
|
||||
run_cmd(
|
||||
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
|
||||
f"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}".strip(),
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
@@ -25,10 +25,13 @@ from axolotl.utils.config import (
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||
from axolotl.utils.tee import prepare_debug_log
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = get_logger(__name__, use_environ=True)
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
API_KEY_FIELDS = {"comet_api_key"}
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
|
||||
@@ -155,6 +158,8 @@ def prepare_plugins(cfg: DictDefault):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
for plugin_name in cfg["plugins"]:
|
||||
plugin_manager.register(plugin_name)
|
||||
for plugin in plugin_manager.plugins.values():
|
||||
plugin.register(cfg)
|
||||
|
||||
|
||||
def plugin_set_cfg(cfg: DictDefault):
|
||||
@@ -202,19 +207,18 @@ def load_cfg(
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or not cfg.strict:
|
||||
# handle booleans
|
||||
if isinstance(cfg[k], bool):
|
||||
cfg[k] = bool(kwargs[k])
|
||||
for key, value in kwargs.items():
|
||||
# If not strict, allow writing to cfg even if it's not in the yml already
|
||||
if key in cfg_keys or not cfg.strict:
|
||||
if isinstance(cfg[key], bool):
|
||||
cfg[key] = bool(value)
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
cfg[key] = value
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||
except: # pylint: disable=bare-except # noqa: E722
|
||||
except:
|
||||
gpu_version = None
|
||||
|
||||
prepare_plugins(cfg)
|
||||
@@ -231,8 +235,11 @@ def load_cfg(
|
||||
},
|
||||
)
|
||||
|
||||
# NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we
|
||||
# have to wait for cfg.output to be resolved. We could call this earlier if we write
|
||||
# to a temporary file, and then move it later.
|
||||
prepare_debug_log(cfg)
|
||||
prepare_optim_env(cfg)
|
||||
prepare_opinionated_env(cfg)
|
||||
normalize_config(cfg)
|
||||
normalize_cfg_datasets(cfg)
|
||||
setup_wandb_env_vars(cfg)
|
||||
@@ -241,5 +248,14 @@ def load_cfg(
|
||||
plugin_set_cfg(cfg)
|
||||
|
||||
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
||||
cfg_to_log = {
|
||||
k: "[REDACTED]" if k in API_KEY_FIELDS else v
|
||||
for k, v in cfg.items()
|
||||
if v is not None
|
||||
}
|
||||
LOG.info(
|
||||
"config:\n%s",
|
||||
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Generator, Union
|
||||
import fire
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from dotenv import load_dotenv
|
||||
from transformers import AutoProcessor
|
||||
|
||||
|
||||
@@ -86,9 +85,7 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
|
||||
unpatch_llama4 = patch_llama4_linearized_modeling()
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
|
||||
model_ = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model, torch_dtype=torch.bfloat16
|
||||
)
|
||||
model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16)
|
||||
processor = AutoProcessor.from_pretrained(model)
|
||||
processor.save_pretrained(output)
|
||||
|
||||
@@ -152,5 +149,4 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -5,16 +5,13 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.evaluate import evaluate
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -31,11 +28,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
@@ -56,7 +49,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
@@ -66,5 +59,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -9,16 +9,18 @@ from typing import Union
|
||||
import fire
|
||||
import torch
|
||||
import transformers
|
||||
from dotenv import load_dotenv
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
|
||||
from axolotl.cli.args import InferenceCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.cli.utils.diffusion import (
|
||||
diffusion_inference,
|
||||
launch_diffusion_gradio_ui,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.chat_templates import (
|
||||
get_chat_template,
|
||||
get_chat_template_from_config,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -35,10 +37,11 @@ def get_multi_line_input() -> str:
|
||||
Possibly multi-line, possibly empty stdin input as a string.
|
||||
"""
|
||||
print("Give me an instruction (Ctrl + D to submit): ")
|
||||
print("=" * 80)
|
||||
|
||||
instruction = ""
|
||||
for line in sys.stdin:
|
||||
instruction += line # pylint: disable=consider-using-join
|
||||
instruction += line
|
||||
|
||||
return instruction
|
||||
|
||||
@@ -50,9 +53,9 @@ def do_inference(
|
||||
cli_args: InferenceCliArgs,
|
||||
):
|
||||
"""
|
||||
Runs inference on the command line in a loop. User input is accepted, a chat template
|
||||
is (optionally) applied, and the model specified in the `axolotl` config is used to
|
||||
generate completions according to a default generation config.
|
||||
Runs inference on the command line in a loop. User input is accepted, a chat
|
||||
template is (optionally) applied, and the model specified in the `axolotl` config is
|
||||
used to generate completions according to a default generation config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
@@ -68,17 +71,31 @@ def do_inference(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template)
|
||||
elif cfg.datasets[0].type == "chat_template":
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg, ds_cfg=None, tokenizer=tokenizer
|
||||
)
|
||||
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||
)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
# Detect diffusion mode
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
is_diffusion = any(
|
||||
plugin.__class__.__name__ == "DiffusionPlugin"
|
||||
for plugin in plugin_manager.plugins.values()
|
||||
)
|
||||
|
||||
if is_diffusion:
|
||||
print("=" * 80)
|
||||
print("Commands:")
|
||||
print(":complete N -> completion mode with N tokens (default 64)")
|
||||
print(":mask R -> random masking with ratio R (0.0–1.0)")
|
||||
|
||||
while True:
|
||||
print("=" * 80)
|
||||
# support for multiline inputs
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
@@ -108,9 +125,19 @@ def do_inference(
|
||||
else:
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
print("=" * 80)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
if is_diffusion:
|
||||
diffusion_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
prompt=prompt,
|
||||
chat_template_str=chat_template_str,
|
||||
)
|
||||
continue
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
@@ -133,7 +160,7 @@ def do_inference(
|
||||
generation_config=generation_config,
|
||||
streamer=streamer,
|
||||
)
|
||||
print("=" * 40)
|
||||
print("=" * 80)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
@@ -164,15 +191,37 @@ def do_inference_gradio(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg, ds_cfg=None, tokenizer=tokenizer
|
||||
)
|
||||
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||
)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
# Detect diffusion mode
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
is_diffusion = any(
|
||||
plugin.__class__.__name__ == "DiffusionPlugin"
|
||||
for plugin in plugin_manager.plugins.values()
|
||||
)
|
||||
|
||||
if is_diffusion:
|
||||
launch_diffusion_gradio_ui(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
prompter_module=prompter_module,
|
||||
chat_template_str=chat_template_str,
|
||||
)
|
||||
return
|
||||
|
||||
def generate(instruction):
|
||||
if not instruction:
|
||||
return
|
||||
if prompter_module:
|
||||
# pylint: disable=stop-iteration-return
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
@@ -257,8 +306,7 @@ def do_cli(
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
|
||||
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
||||
@@ -273,5 +321,4 @@ def do_cli(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
"""Click CLI definitions for various axolotl commands."""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import os
|
||||
import subprocess # nosec B404
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import axolotl
|
||||
@@ -20,26 +15,36 @@ from axolotl.cli.args import (
|
||||
TrainerCliArgs,
|
||||
VllmServeCliArgs,
|
||||
)
|
||||
from axolotl.cli.sweeps import generate_sweep_configs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
add_options_from_dataclass,
|
||||
build_command,
|
||||
fetch_from_github,
|
||||
filter_none_kwargs,
|
||||
generate_config_files,
|
||||
launch_training,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
LAUNCHER_COMMAND_MAPPING = {
|
||||
"accelerate": ["accelerate", "launch"],
|
||||
"torchrun": ["torchrun"],
|
||||
}
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||
def cli():
|
||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||
print_axolotl_text_art()
|
||||
load_dotenv()
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -48,7 +53,7 @@ def cli():
|
||||
@add_options_from_dataclass(PreprocessCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Preprocess datasets before training.
|
||||
|
||||
@@ -58,7 +63,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
patch_optimized_env()
|
||||
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_preprocess
|
||||
@@ -70,12 +74,15 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU training",
|
||||
)
|
||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
@@ -86,126 +93,82 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
@click.pass_context
|
||||
def train(
|
||||
ctx: click.Context,
|
||||
config: str,
|
||||
accelerate: bool,
|
||||
cloud: Optional[str] = None,
|
||||
sweep: Optional[str] = None,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
cloud: str | None = None,
|
||||
sweep: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Train or fine-tune a model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python").
|
||||
cloud: Path to a cloud accelerator configuration file
|
||||
sweep: Path to YAML config for sweeping hyperparameters.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
if sweep:
|
||||
# load the sweep configuration yaml file
|
||||
with open(sweep, "r", encoding="utf-8") as fin:
|
||||
sweep_config: dict[str, list] = yaml.safe_load(fin)
|
||||
with open(config, "r", encoding="utf-8") as fin:
|
||||
base_config: dict[str, list] = yaml.safe_load(fin)
|
||||
# Handle Ray launcher override
|
||||
_launcher = None if kwargs.get("use_ray") else launcher
|
||||
|
||||
# generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
|
||||
def iter_configs():
|
||||
for perm in permutations:
|
||||
# open temp directory for temporary configurations
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
with open(
|
||||
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
|
||||
) as fout:
|
||||
yaml.dump(perm, fout)
|
||||
yield str(Path(temp_dir) / "config.yaml")
|
||||
|
||||
else:
|
||||
|
||||
def iter_configs():
|
||||
yield config
|
||||
|
||||
for cfg_file in iter_configs():
|
||||
# handle errors from subprocess so we can continue rest of sweeps
|
||||
# Process each configuration
|
||||
for cfg_file, is_group in generate_config_files(config, sweep):
|
||||
try:
|
||||
if accelerate:
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
cwd = os.getcwd()
|
||||
do_cli_train(
|
||||
cloud_config=cloud,
|
||||
config=config,
|
||||
accelerate=True,
|
||||
cwd=cwd,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port", None)
|
||||
accelerate_args.append("--main_process_port")
|
||||
accelerate_args.append(str(main_process_port))
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes", None)
|
||||
accelerate_args.append("--num_processes")
|
||||
accelerate_args.append(str(num_processes))
|
||||
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
base_cmd.extend(accelerate_args)
|
||||
base_cmd.extend(["-m", "axolotl.cli.train"])
|
||||
if cfg_file:
|
||||
base_cmd.append(cfg_file)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
do_cli_train(
|
||||
cloud_config=cloud, config=config, accelerate=False, **kwargs
|
||||
)
|
||||
else:
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
do_cli(config=cfg_file, **kwargs)
|
||||
use_exec = is_group is not True
|
||||
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||
if not sweep:
|
||||
raise exc
|
||||
finally:
|
||||
# Only delete temp files, not the original config
|
||||
if cfg_file != config:
|
||||
os.unlink(cfg_file)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU evaluation",
|
||||
)
|
||||
@add_options_from_dataclass(EvaluateCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
|
||||
"""
|
||||
Evaluate a model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python").
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.evaluate"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
@@ -216,30 +179,42 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=False,
|
||||
help="Use accelerate launch for multi-GPU inference",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU inference",
|
||||
)
|
||||
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):
|
||||
"""
|
||||
Run inference with a trained model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python").
|
||||
gradio: Whether to use Gradio browser interface or command line for inference.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.inference"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
if gradio:
|
||||
@@ -252,33 +227,42 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
||||
do_cli(config=config, gradio=gradio, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for weight merging",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for weight merging",
|
||||
)
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def merge_sharded_fsdp_weights(
|
||||
ctx: click.Context, config: str, launcher: str, **kwargs
|
||||
):
|
||||
"""
|
||||
Merge sharded FSDP model weights.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python").
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"-m",
|
||||
"axolotl.cli.merge_sharded_fsdp_weights",
|
||||
]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.merge_sharded_fsdp_weights"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
@@ -294,7 +278,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def merge_lora(config: str, **kwargs) -> None:
|
||||
def merge_lora(config: str, **kwargs):
|
||||
"""
|
||||
Merge trained LoRA adapters into a base model.
|
||||
|
||||
@@ -311,7 +295,7 @@ def merge_lora(config: str, **kwargs) -> None:
|
||||
@cli.command()
|
||||
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
|
||||
@click.option("--dest", help="Destination directory")
|
||||
def fetch(directory: str, dest: Optional[str]) -> None:
|
||||
def fetch(directory: str, dest: Optional[str]):
|
||||
"""
|
||||
Fetch example configs or other resources.
|
||||
|
||||
@@ -349,7 +333,7 @@ def quantize(config: str, **cli_args: QuantizeCliArgs):
|
||||
@cli.command()
|
||||
@click.argument("model", type=click.Path(exists=True, path_type=str))
|
||||
@click.argument("output", type=click.Path(exists=False, path_type=str))
|
||||
def delinearize_llama4(model: str, output: str) -> None:
|
||||
def delinearize_llama4(model: str, output: str):
|
||||
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
|
||||
|
||||
do_delinearize_llama4(model, output)
|
||||
@@ -363,5 +347,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
main()
|
||||
|
||||
@@ -4,9 +4,7 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
@@ -25,8 +23,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
|
||||
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
@@ -49,7 +45,10 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
safe_serialization=safe_serialization,
|
||||
progressbar=True,
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
tokenizer.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||
)
|
||||
|
||||
if processor:
|
||||
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
@@ -75,7 +74,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
flash_attention=False,
|
||||
sequence_parallel_degree=None,
|
||||
context_parallel_size=None,
|
||||
deepspeed=None,
|
||||
fsdp=None,
|
||||
fsdp_config=None,
|
||||
@@ -93,5 +92,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -10,6 +10,7 @@ import fire
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dist_cp
|
||||
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
@@ -17,15 +18,14 @@ from accelerate.utils import (
|
||||
WEIGHTS_NAME,
|
||||
is_torch_version,
|
||||
)
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.train import determine_last_checkpoint
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -33,7 +33,7 @@ LOG = get_logger(__name__)
|
||||
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
|
||||
"""A custom planner to cast tensors to bfloat16 on the fly during loading."""
|
||||
|
||||
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
|
||||
def commit_tensor(self, read_item, tensor):
|
||||
tensor.copy_(tensor.to(torch.bfloat16))
|
||||
|
||||
|
||||
@@ -60,10 +60,10 @@ def _distributed_checkpoint_to_merged_weights(
|
||||
state_dict: Dict = {}
|
||||
save_path_ = Path(save_path)
|
||||
save_path_.mkdir(exist_ok=True)
|
||||
dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access
|
||||
dist_cp_format_utils._load_state_dict(
|
||||
state_dict,
|
||||
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
|
||||
planner=BFloat16CastPlanner(), # pylint: disable=protected-access
|
||||
planner=BFloat16CastPlanner(),
|
||||
no_dist=True,
|
||||
)
|
||||
|
||||
@@ -147,7 +147,6 @@ def merge_fsdp_weights(
|
||||
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
|
||||
"""
|
||||
checkpoint_dir_ = Path(checkpoint_dir)
|
||||
from accelerate.state import PartialState
|
||||
|
||||
if not is_torch_version(">=", "2.3.0"):
|
||||
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
|
||||
@@ -184,7 +183,6 @@ def merge_fsdp_weights(
|
||||
if remove_checkpoint_dir:
|
||||
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
|
||||
shutil.rmtree(checkpoint_dir_)
|
||||
state.wait_for_everyone()
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
@@ -195,18 +193,37 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||
if not fsdp_dir.exists():
|
||||
checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)
|
||||
if checkpoint_dir:
|
||||
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
|
||||
if not fsdp_dir.exists():
|
||||
raise ValueError(
|
||||
f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
|
||||
)
|
||||
|
||||
output_path = str(Path(parsed_cfg.output_dir) / "merged")
|
||||
merge_fsdp_weights(
|
||||
checkpoint_dir=str(fsdp_dir),
|
||||
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
|
||||
output_path=output_path,
|
||||
safe_serialization=True,
|
||||
)
|
||||
state = PartialState()
|
||||
state.wait_for_everyone()
|
||||
LOG.info(
|
||||
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
|
||||
main_process_only=True,
|
||||
)
|
||||
LOG.info(
|
||||
"Merged weights are only the safetensors and doesn't include the model configuration "
|
||||
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""CLI to run preprocessing of a dataset."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
@@ -8,11 +9,9 @@ import fire
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from colorama import Fore
|
||||
from dotenv import load_dotenv
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.cli.args import PreprocessCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
@@ -35,10 +34,26 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Preprocessing-specific CLI arguments.
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
|
||||
if cli_args.iterable:
|
||||
LOG.error(
|
||||
"The --iterable CLI argument for 'axolotl preprocess' is no longer "
|
||||
"supported. For training, set 'streaming: true' in your YAML config or "
|
||||
"pass '--streaming' in your 'axolotl train' command for on-the-fly "
|
||||
"preprocessing."
|
||||
)
|
||||
return
|
||||
|
||||
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
|
||||
if cfg.get(key):
|
||||
LOG.error(
|
||||
f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
|
||||
"train' CLI directly instead."
|
||||
)
|
||||
return
|
||||
|
||||
if not cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
@@ -70,7 +85,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
AutoModelForCausalLM.from_pretrained(
|
||||
model_name, trust_remote_code=True
|
||||
)
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught,unused-variable # nosec B110 # noqa F841
|
||||
except Exception: # nosec B110
|
||||
pass
|
||||
# fmt: on
|
||||
|
||||
@@ -92,8 +107,10 @@ def do_cli(
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
|
||||
is_preprocess = kwargs.pop("is_preprocess", True)
|
||||
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
|
||||
parsed_cfg.is_preprocess = True
|
||||
parser = transformers.HfArgumentParser(PreprocessCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
@@ -104,5 +121,4 @@ def do_cli(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -5,13 +5,17 @@ CLI to post-training quantize a model using torchao
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
|
||||
from axolotl.utils.quantization import (
|
||||
TorchAOQuantDType,
|
||||
get_quantization_config,
|
||||
quantization_config_to_str,
|
||||
quantize_model,
|
||||
)
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -27,7 +31,6 @@ def do_quantize(
|
||||
config (Union[Path, str]): The path to the config file
|
||||
cli_args (dict): Additional command-line arguments
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
|
||||
cfg = load_cfg(config)
|
||||
|
||||
@@ -45,13 +48,13 @@ def do_quantize(
|
||||
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
||||
)
|
||||
|
||||
model_path = cli_args.get("model_path") or cfg.output_dir
|
||||
model_path = cli_args.get("base_model") or cfg.output_dir
|
||||
if weight_dtype := cli_args.get("weight_dtype"):
|
||||
weight_dtype = TorchIntDType[weight_dtype]
|
||||
weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
|
||||
else:
|
||||
weight_dtype = quantize_cfg.weight_dtype
|
||||
if activation_dtype := cli_args.get("activation_dtype"):
|
||||
activation_dtype = TorchIntDType[activation_dtype]
|
||||
activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
|
||||
else:
|
||||
activation_dtype = quantize_cfg.activation_dtype
|
||||
group_size = cli_args.get("group_size") or quantize_cfg.group_size
|
||||
@@ -59,10 +62,15 @@ def do_quantize(
|
||||
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
|
||||
)
|
||||
output_dir = cli_args.get("output_dir") or cfg.output_dir
|
||||
hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id
|
||||
|
||||
LOG.info(f"Loading model from {model_path}...")
|
||||
LOG.info(f"Loading model from {model_path}.")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, device_map="auto", dtype=torch_dtype
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
f"Quantizing model with configuration: \n"
|
||||
@@ -72,11 +80,21 @@ def do_quantize(
|
||||
f"\tquantize_embedding: {quantize_embedding}"
|
||||
)
|
||||
|
||||
quantize_model_for_ptq(
|
||||
quantize_model(
|
||||
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
||||
)
|
||||
|
||||
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
|
||||
quantization_config = get_quantization_config(
|
||||
weight_dtype, activation_dtype, group_size
|
||||
)
|
||||
|
||||
ao_config = TorchAoConfig(
|
||||
quant_type=quantization_config,
|
||||
include_input_output_embeddings=quantize_embedding,
|
||||
)
|
||||
model.config.quantization_config = ao_config
|
||||
|
||||
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
|
||||
model.save_pretrained(
|
||||
str(Path(output_dir) / "quantized"),
|
||||
safe_serialization=False,
|
||||
@@ -86,5 +104,16 @@ def do_quantize(
|
||||
str(Path(output_dir) / "quantized"),
|
||||
safe_serialization=False,
|
||||
progressbar=True,
|
||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||
)
|
||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
|
||||
|
||||
if hub_model_id:
|
||||
hub_model_id = (
|
||||
hub_model_id.rstrip("-")
|
||||
+ f"-{quantization_config_to_str[type(quantization_config)]}"
|
||||
)
|
||||
model.push_to_hub(hub_model_id, safe_serialization=False)
|
||||
tokenizer.push_to_hub(hub_model_id)
|
||||
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
|
||||
|
||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
|
||||
|
||||
@@ -7,19 +7,17 @@ from typing import Union
|
||||
|
||||
import fire
|
||||
from accelerate import Accelerator
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
|
||||
|
||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
@@ -32,10 +30,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Training-specific CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
@@ -66,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
@@ -99,23 +92,30 @@ def ray_train_func(kwargs: dict):
|
||||
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
|
||||
# also renormalize the config now that TorchTrainer has spawned distributed workers
|
||||
cfg = DictDefault(kwargs["cfg"])
|
||||
prepare_optim_env(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
|
||||
resolve_dtype(cfg)
|
||||
|
||||
# ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
|
||||
if cfg.deepspeed:
|
||||
if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"):
|
||||
cfg.deepspeed = cfg.deepspeed.to_dict()
|
||||
|
||||
# initialize accelerator before model instantiation
|
||||
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
|
||||
|
||||
# Register plugins in Ray workers
|
||||
if cfg.get("plugins"):
|
||||
from axolotl.cli.config import plugin_set_cfg, prepare_plugins
|
||||
|
||||
prepare_plugins(cfg)
|
||||
plugin_set_cfg(cfg)
|
||||
|
||||
kwargs["cfg"] = cfg
|
||||
|
||||
do_train(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -1,330 +0,0 @@
|
||||
"""Utility methods for axolotl CLI."""
|
||||
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def strip_optional_type(field_type: type | str | None):
|
||||
"""
|
||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||
|
||||
Args:
|
||||
field_type: Type of field for Axolotl CLI command.
|
||||
|
||||
Returns:
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
|
||||
return field_type
|
||||
|
||||
|
||||
def filter_none_kwargs(func: Callable) -> Callable:
|
||||
"""
|
||||
Wraps function to remove `None`-valued `kwargs`.
|
||||
|
||||
Args:
|
||||
func: Function to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Callable:
|
||||
"""Filters out `None`-valued `kwargs`."""
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
return func(*args, **filtered_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
config_class: Dataclass with fields to parse from the CLI.
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process dataclass fields in reverse order for correct option ordering
|
||||
for field in reversed(dataclasses.fields(config_class)):
|
||||
field_type = strip_optional_type(field.type)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = field.name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{field.name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process model fields in reverse order for correct option ordering
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = strip_optional_type(field.annotation)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Build command list from base command and options.
|
||||
|
||||
Args:
|
||||
base_cmd: Command without options.
|
||||
options: Options to parse and append to base command.
|
||||
|
||||
Returns:
|
||||
List of strings giving shell command.
|
||||
"""
|
||||
cmd = base_cmd.copy()
|
||||
|
||||
for key, value in options.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
key = key.replace("_", "-")
|
||||
|
||||
if isinstance(value, bool):
|
||||
if value:
|
||||
cmd.append(f"--{key}")
|
||||
else:
|
||||
cmd.extend([f"--{key}", str(value)])
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def download_file(
|
||||
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Download a single file and return its processing status.
|
||||
|
||||
Args:
|
||||
file_info: Tuple of (file_path, remote_sha).
|
||||
raw_base_url: Base URL for raw GitHub content.
|
||||
dest_path: Local destination directory.
|
||||
dir_prefix: Directory prefix to filter files.
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
|
||||
"""
|
||||
file_path, remote_sha = file_info
|
||||
raw_url = f"{raw_base_url}/{file_path}"
|
||||
dest_file = dest_path / file_path.split(dir_prefix)[-1]
|
||||
|
||||
# Check if file exists and needs updating
|
||||
if dest_file.exists():
|
||||
with open(dest_file, "rb") as file:
|
||||
content = file.read()
|
||||
# Calculate git blob SHA
|
||||
blob = b"blob " + str(len(content)).encode() + b"\0" + content
|
||||
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
|
||||
|
||||
if local_sha == remote_sha:
|
||||
print(f"Skipping {file_path} (unchanged)")
|
||||
return file_path, "unchanged"
|
||||
|
||||
print(f"Updating {file_path}")
|
||||
status = "new"
|
||||
else:
|
||||
print(f"Downloading {file_path}")
|
||||
status = "new"
|
||||
|
||||
# Create directories if needed
|
||||
dest_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download and save file
|
||||
try:
|
||||
response = requests.get(raw_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(dest_file, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
return file_path, status
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error downloading {file_path}: {str(request_error)}")
|
||||
return file_path, "error"
|
||||
|
||||
|
||||
def fetch_from_github(
|
||||
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
|
||||
) -> None:
|
||||
"""
|
||||
Sync files from a specific directory in the GitHub repository.
|
||||
Only downloads files that don't exist locally or have changed.
|
||||
|
||||
Args:
|
||||
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
|
||||
'deepspeed_configs/').
|
||||
dest_dir: Local destination directory.
|
||||
max_workers: Maximum number of concurrent downloads.
|
||||
"""
|
||||
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
|
||||
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
|
||||
|
||||
# Get repository tree with timeout
|
||||
response = requests.get(api_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
tree = json.loads(response.text)
|
||||
|
||||
# Filter for files and get their SHA
|
||||
files = {
|
||||
item["path"]: item["sha"]
|
||||
for item in tree["tree"]
|
||||
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
|
||||
}
|
||||
|
||||
if not files:
|
||||
raise click.ClickException(f"No files found in {dir_prefix}")
|
||||
|
||||
# Default destination directory is the last part of dir_prefix
|
||||
default_dest = Path(dir_prefix.rstrip("/"))
|
||||
dest_path = Path(dest_dir) if dest_dir else default_dest
|
||||
|
||||
# Keep track of processed files for summary
|
||||
files_processed: dict[str, list[str]] = {
|
||||
"new": [],
|
||||
"updated": [],
|
||||
"unchanged": [],
|
||||
"error": [],
|
||||
}
|
||||
|
||||
# Process files in parallel using ThreadPoolExecutor
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_file = {
|
||||
executor.submit(
|
||||
download_file,
|
||||
(file_path, remote_sha),
|
||||
raw_base_url,
|
||||
dest_path,
|
||||
dir_prefix,
|
||||
): file_path
|
||||
for file_path, remote_sha in files.items()
|
||||
}
|
||||
|
||||
# Process completed tasks as they finish
|
||||
for future in concurrent.futures.as_completed(future_to_file):
|
||||
file_path = future_to_file[future]
|
||||
try:
|
||||
file_path, status = future.result()
|
||||
files_processed[status].append(file_path)
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error processing {file_path}: {str(request_error)}")
|
||||
files_processed["error"].append(file_path)
|
||||
|
||||
# Log summary
|
||||
LOG.info("\nSync Summary:")
|
||||
LOG.info(f"New files: {len(files_processed['new'])}")
|
||||
LOG.info(f"Updated files: {len(files_processed['updated'])}")
|
||||
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
|
||||
if files_processed["error"]:
|
||||
LOG.info(f"Failed files: {len(files_processed['error'])}")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
inference: bool = False,
|
||||
) -> tuple[
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
||||
config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
inference: Boolean denoting inference mode.
|
||||
|
||||
Returns:
|
||||
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
||||
"""
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
LOG.info("loading model...")
|
||||
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
|
||||
model, _ = model_loader.load()
|
||||
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
LOG.info("loading processor...")
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
return model, tokenizer, processor
|
||||
23
src/axolotl/cli/utils/__init__.py
Normal file
23
src/axolotl/cli/utils/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Init for axolotl.cli.utils module."""
|
||||
|
||||
from .args import (
|
||||
add_options_from_config,
|
||||
add_options_from_dataclass,
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from .fetch import fetch_from_github
|
||||
from .load import load_model_and_tokenizer
|
||||
from .sweeps import generate_sweep_configs
|
||||
from .train import build_command, generate_config_files, launch_training
|
||||
|
||||
__all__ = [
|
||||
"filter_none_kwargs",
|
||||
"add_options_from_dataclass",
|
||||
"add_options_from_config",
|
||||
"build_command",
|
||||
"generate_config_files",
|
||||
"generate_sweep_configs",
|
||||
"load_model_and_tokenizer",
|
||||
"launch_training",
|
||||
"fetch_from_github",
|
||||
]
|
||||
120
src/axolotl/cli/utils/args.py
Normal file
120
src/axolotl/cli/utils/args.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Utilities for axolotl CLI args."""
|
||||
|
||||
import dataclasses
|
||||
from functools import wraps
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def _strip_optional_type(field_type: type | str | None):
|
||||
"""
|
||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||
|
||||
Args:
|
||||
field_type: Type of field for Axolotl CLI command.
|
||||
|
||||
Returns:
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
|
||||
return field_type
|
||||
|
||||
|
||||
def filter_none_kwargs(func: Callable) -> Callable:
|
||||
"""
|
||||
Wraps function to remove `None`-valued `kwargs`.
|
||||
|
||||
Args:
|
||||
func: Function to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Callable:
|
||||
"""Filters out `None`-valued `kwargs`."""
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
return func(*args, **filtered_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
config_class: Dataclass with fields to parse from the CLI.
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process dataclass fields in reverse order for correct option ordering
|
||||
for field in reversed(dataclasses.fields(config_class)):
|
||||
field_type = _strip_optional_type(field.type)
|
||||
|
||||
if field_type is bool:
|
||||
field_name = field.name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{field.name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process model fields in reverse order for correct option ordering
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = _strip_optional_type(field.annotation)
|
||||
|
||||
if field_type is bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
374
src/axolotl/cli/utils/diffusion.py
Normal file
374
src/axolotl/cli/utils/diffusion.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""Helpers for diffusion-mode inference in CLI and Gradio."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gradio as gr
|
||||
from colorama import Fore, Style
|
||||
|
||||
from axolotl.integrations.diffusion import generate, resolve_mask_token_id
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def diffusion_inference(
|
||||
model,
|
||||
tokenizer,
|
||||
cfg,
|
||||
prompt: str,
|
||||
chat_template_str: str | None = None,
|
||||
):
|
||||
"""Diffusion inference helper method."""
|
||||
mode = "random"
|
||||
completion_tokens = 0
|
||||
target_mask_ratio = None
|
||||
mode, completion_tokens, target_mask_ratio, cleaned = _parse_commands(prompt)
|
||||
|
||||
if cleaned:
|
||||
prompt = cleaned
|
||||
|
||||
info = run_diffusion(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
prompt=prompt,
|
||||
chat_template_str=chat_template_str,
|
||||
mode=mode,
|
||||
target_mask_ratio=target_mask_ratio,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
masked_text = info["masked_text"]
|
||||
mask_ratio = info["mask_ratio"]
|
||||
generated_ids = info["generated_ids"]
|
||||
masked_positions = info["masked_positions"]
|
||||
orig_ids = info["orig_ids"]
|
||||
|
||||
# Display with masked preview and colored diff
|
||||
if masked_text is not None and mask_ratio is not None:
|
||||
print(f"Masked ({mask_ratio:.1%}):\n{masked_text}\n")
|
||||
if generated_ids is not None:
|
||||
# Compute per-token style
|
||||
styles: list[str] = []
|
||||
for i, tid in enumerate(generated_ids):
|
||||
if i in masked_positions:
|
||||
if i < len(orig_ids) and tid == orig_ids[i]:
|
||||
styles.append("green") # correct fill
|
||||
elif i < len(orig_ids):
|
||||
styles.append("red") # incorrect fill
|
||||
else:
|
||||
styles.append("normal") # appended
|
||||
else:
|
||||
same = i < len(orig_ids) and tid == orig_ids[i]
|
||||
styles.append("dim" if same else "normal")
|
||||
|
||||
# Group contiguous spans by style
|
||||
styled_spans: list[tuple[str, int, int]] = []
|
||||
if generated_ids:
|
||||
current_style = styles[0]
|
||||
start = 0
|
||||
for i in range(1, len(generated_ids)):
|
||||
s = styles[i]
|
||||
if s != current_style:
|
||||
styled_spans.append((current_style, start, i))
|
||||
current_style, start = s, i
|
||||
styled_spans.append((current_style, start, len(generated_ids)))
|
||||
|
||||
out_parts = []
|
||||
for style_name, a, b in styled_spans:
|
||||
chunk_text = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
|
||||
if style_name == "green":
|
||||
out_parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
|
||||
elif style_name == "red":
|
||||
out_parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
|
||||
else:
|
||||
if style_name == "dim":
|
||||
out_parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
|
||||
else:
|
||||
out_parts.append(chunk_text)
|
||||
print("Generated:\n" + "".join(out_parts))
|
||||
else:
|
||||
print("Generated:\n(no output)")
|
||||
|
||||
|
||||
def _parse_commands(text: str):
|
||||
"""
|
||||
Parse leading diffusion commands.
|
||||
|
||||
Supported at start of input (can be chained):
|
||||
:complete N -> completion mode with N tokens (default 64)
|
||||
:mask R -> random masking with ratio R in [0, 1]
|
||||
"""
|
||||
tokens = text.strip().split()
|
||||
i = 0
|
||||
mode = "random"
|
||||
completion_tokens = 0
|
||||
target_mask_ratio = None
|
||||
consumed = 0
|
||||
while i < len(tokens) and tokens[i].startswith(":"):
|
||||
cmd = tokens[i]
|
||||
i += 1
|
||||
consumed = i
|
||||
if cmd == ":complete":
|
||||
mode = "completion"
|
||||
if i < len(tokens):
|
||||
try:
|
||||
completion_tokens = int(tokens[i])
|
||||
i += 1
|
||||
consumed = i
|
||||
except Exception:
|
||||
completion_tokens = 64
|
||||
else:
|
||||
completion_tokens = 64
|
||||
elif cmd == ":mask":
|
||||
mode = "random"
|
||||
if i < len(tokens):
|
||||
try:
|
||||
target_mask_ratio = float(tokens[i])
|
||||
i += 1
|
||||
consumed = i
|
||||
except Exception:
|
||||
target_mask_ratio = None
|
||||
else:
|
||||
i -= 1
|
||||
consumed = i
|
||||
break
|
||||
|
||||
cleaned = " ".join(tokens[consumed:])
|
||||
|
||||
return mode, completion_tokens, target_mask_ratio, cleaned
|
||||
|
||||
|
||||
def run_diffusion(
|
||||
*,
|
||||
model,
|
||||
tokenizer,
|
||||
cfg: DictDefault,
|
||||
prompt: str,
|
||||
chat_template_str: str | None,
|
||||
mode: str = "random",
|
||||
target_mask_ratio: float | None = None,
|
||||
completion_tokens: int = 0,
|
||||
):
|
||||
"""Run a single diffusion generation and return a structured result dict."""
|
||||
if chat_template_str:
|
||||
batch = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
return_tensors="pt",
|
||||
add_special_tokens=True,
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
else:
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
mask_token_id = resolve_mask_token_id(tokenizer, cfg, allow_add=False)
|
||||
|
||||
seq = batch["input_ids"].to(cfg.device)
|
||||
gen_mode = "completion" if mode == "completion" else "random"
|
||||
comp_tokens = int(completion_tokens) if gen_mode == "completion" else 0
|
||||
|
||||
result = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
original_sequence=seq[:1],
|
||||
num_diffusion_steps=cfg.diffusion.num_diffusion_steps,
|
||||
temperature=cfg.diffusion.generation_temperature,
|
||||
mask_token_id=int(mask_token_id),
|
||||
mode=gen_mode, # type: ignore[arg-type]
|
||||
completion_tokens=comp_tokens,
|
||||
target_mask_ratio=target_mask_ratio,
|
||||
)
|
||||
|
||||
masked_text = result.get("masked") if isinstance(result, dict) else None
|
||||
mask_ratio = result.get("mask_ratio") if isinstance(result, dict) else None
|
||||
generated_ids = result.get("generated_ids") if isinstance(result, dict) else None
|
||||
masked_positions = (
|
||||
set(result.get("masked_positions") or []) if isinstance(result, dict) else set()
|
||||
)
|
||||
orig_ids = seq[0].detach().cpu().tolist()
|
||||
|
||||
return {
|
||||
"masked_text": masked_text,
|
||||
"mask_ratio": mask_ratio,
|
||||
"generated_ids": generated_ids,
|
||||
"masked_positions": masked_positions,
|
||||
"orig_ids": orig_ids,
|
||||
}
|
||||
|
||||
|
||||
def render_html(
|
||||
*,
|
||||
generated_ids: list[int] | None,
|
||||
orig_ids: list[int],
|
||||
masked_positions: set[int],
|
||||
tokenizer,
|
||||
) -> str:
|
||||
"""Render HTML visualizing diffusion outputs."""
|
||||
if not generated_ids:
|
||||
return "<pre>Generated:\n(no output)</pre>"
|
||||
|
||||
def _style_for(i: int, tid: int) -> str:
|
||||
if i in masked_positions:
|
||||
if i < len(orig_ids) and tid == orig_ids[i]:
|
||||
return "green"
|
||||
if i < len(orig_ids):
|
||||
return "red"
|
||||
return "normal"
|
||||
same = i < len(orig_ids) and tid == orig_ids[i]
|
||||
return "dim" if same else "normal"
|
||||
|
||||
# Group contiguous spans by style to reduce HTML size
|
||||
spans: list[tuple[str, int, int]] = []
|
||||
if generated_ids:
|
||||
cur = _style_for(0, generated_ids[0])
|
||||
start = 0
|
||||
for i in range(1, len(generated_ids)):
|
||||
s = _style_for(i, generated_ids[i])
|
||||
if s != cur:
|
||||
spans.append((cur, start, i))
|
||||
cur, start = s, i
|
||||
spans.append((cur, start, len(generated_ids)))
|
||||
|
||||
html_parts = []
|
||||
for style_name, a, b in spans:
|
||||
txt = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
|
||||
if style_name == "green":
|
||||
html_parts.append(f'<span style="color:#2e7d32">{txt}</span>')
|
||||
elif style_name == "red":
|
||||
html_parts.append(f'<span style="color:#c62828">{txt}</span>')
|
||||
elif style_name == "dim":
|
||||
html_parts.append(f'<span style="opacity:0.6">{txt}</span>')
|
||||
else:
|
||||
html_parts.append(txt)
|
||||
|
||||
legend = (
|
||||
'<div style="font-size:0.9em;margin-bottom:4px">'
|
||||
'<span style="color:#2e7d32">correct</span>, '
|
||||
'<span style="color:#c62828">incorrect</span>, '
|
||||
'<span style="opacity:0.6">unchanged</span>'
|
||||
"</div>"
|
||||
)
|
||||
|
||||
return (
|
||||
legend
|
||||
+ '<pre style="white-space:pre-wrap">Generated:\n'
|
||||
+ "".join(html_parts)
|
||||
+ "</pre>"
|
||||
)
|
||||
|
||||
|
||||
def launch_diffusion_gradio_ui(
|
||||
*,
|
||||
model,
|
||||
tokenizer,
|
||||
cfg: DictDefault,
|
||||
prompter_module=None,
|
||||
chat_template_str: str | None = None,
|
||||
):
|
||||
"""Build and launch a simple Gradio UI for diffusion inference."""
|
||||
with gr.Blocks(
|
||||
title=cfg.get("gradio_title", "Axolotl Diffusion Interface")
|
||||
) as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
## Axolotl Diffusion Inference
|
||||
- Mode "Random" masks tokens at a target ratio and fills them.
|
||||
- Mode "Completion" appends N masked tokens at the end and fills them.
|
||||
"""
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
mode = gr.Radio(
|
||||
choices=["random", "completion"],
|
||||
value="random",
|
||||
label="Mode",
|
||||
)
|
||||
mask_ratio = gr.Slider(
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
step=0.05,
|
||||
value=0.4,
|
||||
label="Mask ratio (random mode)",
|
||||
interactive=True,
|
||||
)
|
||||
completion_tokens = gr.Number(
|
||||
value=64,
|
||||
precision=0,
|
||||
label="Completion tokens (completion mode)",
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
|
||||
instruction = gr.Textbox(label="Instruction", lines=6)
|
||||
run_btn = gr.Button("Generate")
|
||||
|
||||
masked_preview = gr.Textbox(label="Masked preview", lines=6)
|
||||
html_out = gr.HTML(label="Generated")
|
||||
|
||||
def _toggle_controls(selected_mode: str):
|
||||
return (
|
||||
gr.update(visible=(selected_mode == "random")),
|
||||
gr.update(visible=(selected_mode == "completion")),
|
||||
)
|
||||
|
||||
mode.change(
|
||||
_toggle_controls,
|
||||
inputs=[mode],
|
||||
outputs=[mask_ratio, completion_tokens],
|
||||
)
|
||||
|
||||
def _gen(instruction_text: str, selected_mode: str, mratio: float, ctoks: int):
|
||||
if not instruction_text:
|
||||
return "", "<pre>Generated:\n(no output)</pre>"
|
||||
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(
|
||||
instruction=instruction_text.strip("\n")
|
||||
)
|
||||
)
|
||||
else:
|
||||
prompt = instruction_text.strip()
|
||||
|
||||
info = run_diffusion(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
prompt=prompt,
|
||||
chat_template_str=chat_template_str,
|
||||
mode=selected_mode,
|
||||
target_mask_ratio=mratio if selected_mode == "random" else None,
|
||||
completion_tokens=int(ctoks) if selected_mode == "completion" else 0,
|
||||
)
|
||||
|
||||
masked_text = info.get("masked_text")
|
||||
mask_ratio_val = info.get("mask_ratio")
|
||||
generated_ids = info.get("generated_ids")
|
||||
masked_positions = info.get("masked_positions") or set()
|
||||
orig_ids = info.get("orig_ids") or []
|
||||
|
||||
preview = (
|
||||
f"Masked ({mask_ratio_val:.1%}):\n{masked_text}"
|
||||
if masked_text is not None and mask_ratio_val is not None
|
||||
else ""
|
||||
)
|
||||
html = render_html(
|
||||
generated_ids=generated_ids,
|
||||
orig_ids=orig_ids,
|
||||
masked_positions=masked_positions,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
return preview, html
|
||||
|
||||
run_btn.click(
|
||||
_gen,
|
||||
inputs=[instruction, mode, mask_ratio, completion_tokens],
|
||||
outputs=[masked_preview, html_out],
|
||||
)
|
||||
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
)
|
||||
142
src/axolotl/cli/utils/fetch.py
Normal file
142
src/axolotl/cli/utils/fetch.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Utilities for axolotl fetch CLI command."""
|
||||
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import requests
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _download_file(
|
||||
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Download a single file and return its processing status.
|
||||
|
||||
Args:
|
||||
file_info: Tuple of (file_path, remote_sha).
|
||||
raw_base_url: Base URL for raw GitHub content.
|
||||
dest_path: Local destination directory.
|
||||
dir_prefix: Directory prefix to filter files.
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
|
||||
"""
|
||||
file_path, remote_sha = file_info
|
||||
raw_url = f"{raw_base_url}/{file_path}"
|
||||
dest_file = dest_path / file_path.split(dir_prefix)[-1]
|
||||
|
||||
# Check if file exists and needs updating
|
||||
if dest_file.exists():
|
||||
with open(dest_file, "rb") as file:
|
||||
content = file.read()
|
||||
# Calculate git blob SHA
|
||||
blob = b"blob " + str(len(content)).encode() + b"\0" + content
|
||||
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
|
||||
|
||||
if local_sha == remote_sha:
|
||||
print(f"Skipping {file_path} (unchanged)")
|
||||
return file_path, "unchanged"
|
||||
|
||||
print(f"Updating {file_path}")
|
||||
status = "updated"
|
||||
else:
|
||||
print(f"Downloading {file_path}")
|
||||
status = "new"
|
||||
|
||||
# Create directories if needed
|
||||
dest_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download and save file
|
||||
try:
|
||||
response = requests.get(raw_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(dest_file, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
return file_path, status
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error downloading {file_path}: {str(request_error)}")
|
||||
return file_path, "error"
|
||||
|
||||
|
||||
def fetch_from_github(
|
||||
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
|
||||
) -> None:
|
||||
"""
|
||||
Sync files from a specific directory in the GitHub repository.
|
||||
Only downloads files that don't exist locally or have changed.
|
||||
|
||||
Args:
|
||||
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
|
||||
'deepspeed_configs/').
|
||||
dest_dir: Local destination directory.
|
||||
max_workers: Maximum number of concurrent downloads.
|
||||
"""
|
||||
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
|
||||
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
|
||||
|
||||
# Get repository tree with timeout
|
||||
response = requests.get(api_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
tree = json.loads(response.text)
|
||||
|
||||
# Filter for files and get their SHA
|
||||
files = {
|
||||
item["path"]: item["sha"]
|
||||
for item in tree["tree"]
|
||||
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
|
||||
}
|
||||
|
||||
if not files:
|
||||
raise click.ClickException(f"No files found in {dir_prefix}")
|
||||
|
||||
# Default destination directory is the last part of dir_prefix
|
||||
default_dest = Path(dir_prefix.rstrip("/"))
|
||||
dest_path = Path(dest_dir) if dest_dir else default_dest
|
||||
|
||||
# Keep track of processed files for summary
|
||||
files_processed: dict[str, list[str]] = {
|
||||
"new": [],
|
||||
"updated": [],
|
||||
"unchanged": [],
|
||||
"error": [],
|
||||
}
|
||||
|
||||
# Process files in parallel using ThreadPoolExecutor
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_file = {
|
||||
executor.submit(
|
||||
_download_file,
|
||||
(file_path, remote_sha),
|
||||
raw_base_url,
|
||||
dest_path,
|
||||
dir_prefix,
|
||||
): file_path
|
||||
for file_path, remote_sha in files.items()
|
||||
}
|
||||
|
||||
# Process completed tasks as they finish
|
||||
for future in concurrent.futures.as_completed(future_to_file):
|
||||
file_path = future_to_file[future]
|
||||
try:
|
||||
file_path, status = future.result()
|
||||
files_processed[status].append(file_path)
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error processing {file_path}: {str(request_error)}")
|
||||
files_processed["error"].append(file_path)
|
||||
|
||||
# Log summary
|
||||
LOG.info("\nSync Summary:")
|
||||
LOG.info(f"New files: {len(files_processed['new'])}")
|
||||
LOG.info(f"Updated files: {len(files_processed['updated'])}")
|
||||
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
|
||||
if files_processed["error"]:
|
||||
LOG.info(f"Failed files: {len(files_processed['error'])}")
|
||||
52
src/axolotl/cli/utils/load.py
Normal file
52
src/axolotl/cli/utils/load.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Utilities for model, tokenizer, etc. loading."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
inference: bool = False,
|
||||
) -> tuple[
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model, tokenizer, and processor specified in the
|
||||
given `axolotl` config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
inference: Boolean denoting inference mode.
|
||||
|
||||
Returns:
|
||||
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
||||
"""
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
LOG.info("loading model...")
|
||||
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
|
||||
model, _ = model_loader.load()
|
||||
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
LOG.info("loading processor...")
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
return model, tokenizer, processor
|
||||
@@ -3,11 +3,12 @@
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from typing import Any
|
||||
|
||||
|
||||
def generate_sweep_configs(
|
||||
base_config: dict[str, list], sweeps_config: dict[str, list]
|
||||
) -> list[dict[str, list]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Recursively generates all possible configurations by applying sweeps to the base config.
|
||||
|
||||
@@ -48,7 +49,10 @@ def generate_sweep_configs(
|
||||
new_config = {}
|
||||
# new_config = deepcopy(base_config)
|
||||
# Combine regular parameters with paired parameters
|
||||
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
|
||||
full_combo = {
|
||||
**dict(zip(param_names, reg_combo, strict=False)),
|
||||
**paired_set,
|
||||
}
|
||||
for param_name, param_value in full_combo.items():
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
@@ -57,7 +61,7 @@ def generate_sweep_configs(
|
||||
# If no paired values, just use regular combinations
|
||||
# new_config = deepcopy(base_config)
|
||||
new_config = {}
|
||||
for param_name, param_value in zip(param_names, reg_combo):
|
||||
for param_name, param_value in zip(param_names, reg_combo, strict=False):
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
all_combinations.append(new_config)
|
||||
225
src/axolotl/cli/utils/train.py
Normal file
225
src/axolotl/cli/utils/train.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Utilities for axolotl train CLI command."""
|
||||
|
||||
import os
|
||||
import subprocess # nosec
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterator, Literal
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.cli.utils.sweeps import generate_sweep_configs
|
||||
|
||||
|
||||
def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]:
|
||||
"""
|
||||
Add default RDZV arguments if rdzv_endpoint is set but rdzv_backend/rdzv_id are missing.
|
||||
|
||||
Args:
|
||||
launcher_args: List of launcher arguments
|
||||
|
||||
Returns:
|
||||
Updated launcher args with defaults added if needed
|
||||
"""
|
||||
args = launcher_args.copy()
|
||||
|
||||
# Check if rdzv_endpoint is present
|
||||
has_rdzv_endpoint = any("--rdzv_endpoint" in arg for arg in args)
|
||||
|
||||
if has_rdzv_endpoint:
|
||||
# Check if rdzv_backend is already provided
|
||||
has_rdzv_backend = any("--rdzv_backend" in arg for arg in args)
|
||||
if not has_rdzv_backend:
|
||||
args.extend(["--rdzv_backend", "c10d"])
|
||||
|
||||
# Check if rdzv_id is already provided
|
||||
has_rdzv_id = any("--rdzv_id" in arg for arg in args)
|
||||
if not has_rdzv_id:
|
||||
import uuid
|
||||
|
||||
args.extend(["--rdzv_id", str(uuid.uuid4())[:8]])
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Build command list from base command and options.
|
||||
|
||||
Args:
|
||||
base_cmd: Command without options.
|
||||
options: Options to parse and append to base command.
|
||||
|
||||
Returns:
|
||||
List of strings giving shell command.
|
||||
"""
|
||||
cmd = base_cmd.copy()
|
||||
|
||||
for key, value in options.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
key = key.replace("_", "-")
|
||||
cmd.append(f"--{key}={value}")
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
|
||||
"""
|
||||
Generate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating
|
||||
whether this is a group of configurations (i.e., a sweep).
|
||||
|
||||
Args:
|
||||
config: Base configuration file
|
||||
sweep: Sweep configuration file
|
||||
"""
|
||||
|
||||
if not sweep:
|
||||
yield config, False
|
||||
return
|
||||
|
||||
# Load sweep and base configurations
|
||||
with open(sweep, "r", encoding="utf-8") as fin:
|
||||
sweep_config: dict[str, list] = yaml.safe_load(fin)
|
||||
with open(config, "r", encoding="utf-8") as fin:
|
||||
base_config: dict[str, list] = yaml.safe_load(fin)
|
||||
|
||||
# Generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
is_group = len(permutations) > 1
|
||||
base_output_dir = base_config.get("output_dir", "./model-out")
|
||||
for idx, permutation in enumerate(permutations, start=1):
|
||||
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
|
||||
permutation_id = f"sweep{idx:04d}"
|
||||
permutation["output_dir"] = str(permutation_dir / permutation_id)
|
||||
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
suffix=".yaml",
|
||||
delete=False,
|
||||
encoding="utf-8",
|
||||
)
|
||||
yaml.dump(permutation, temp_file)
|
||||
temp_file.close()
|
||||
yield temp_file.name, is_group
|
||||
|
||||
|
||||
def launch_training(
|
||||
cfg_file: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] | None,
|
||||
cloud: str | None,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training with the given configuration."""
|
||||
launcher_args = launcher_args or []
|
||||
|
||||
if cloud:
|
||||
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
||||
elif launcher:
|
||||
if launcher == "accelerate":
|
||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||
elif launcher == "torchrun":
|
||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||
elif launcher == "python":
|
||||
_launch_python_training(cfg_file, kwargs)
|
||||
elif launcher is None:
|
||||
# handle ray train launch
|
||||
_launch_python_training(cfg_file, kwargs)
|
||||
|
||||
|
||||
def _launch_cloud_training(
|
||||
cloud: str,
|
||||
cfg_file: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] | None,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Execute training via cloud launcher."""
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
launcher_args = launcher_args or []
|
||||
cwd = os.getcwd() if launcher else None
|
||||
|
||||
do_cli_train(
|
||||
cloud_config=cloud,
|
||||
config=cfg_file,
|
||||
launcher=launcher or "accelerate",
|
||||
launcher_args=launcher_args,
|
||||
cwd=cwd,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _launch_accelerate_training(
|
||||
cfg_file: str,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training via accelerate launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
internal_launcher_args = []
|
||||
|
||||
# Extract launcher-specific arguments from kwargs (legacy support)
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port")
|
||||
internal_launcher_args.extend(["--main_process_port", str(main_process_port)])
|
||||
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes")
|
||||
internal_launcher_args.extend(["--num_processes", str(num_processes)])
|
||||
|
||||
# Combine internal args with user-provided launcher args
|
||||
all_launcher_args = internal_launcher_args + launcher_args
|
||||
|
||||
base_cmd = (
|
||||
["accelerate", "launch"] + all_launcher_args + ["-m", "axolotl.cli.train"]
|
||||
)
|
||||
if cfg_file:
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
if use_exec:
|
||||
# make sure to flush stdout and stderr before replacing the process
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||
else:
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_torchrun_training(
|
||||
cfg_file: str,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training via torchrun launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
|
||||
# Add default RDZV arguments if rdzv_endpoint is set
|
||||
launcher_args = _add_default_rdzv_args(launcher_args)
|
||||
|
||||
base_cmd = ["torchrun"] + launcher_args + ["-m", "axolotl.cli.train"]
|
||||
if cfg_file:
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
if use_exec:
|
||||
# make sure to flush stdout and stderr before replacing the process
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||
else:
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
||||
"""Execute training via python launcher."""
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
do_cli(config=cfg_file, **kwargs)
|
||||
@@ -2,12 +2,10 @@
|
||||
CLI to start the vllm server for online RL
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import trl
|
||||
from trl.scripts.vllm_serve import ScriptArguments
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
@@ -37,16 +35,22 @@ def do_vllm_serve(
|
||||
Returns:
|
||||
process_id: the process id of the started VLLM server
|
||||
"""
|
||||
patch_vllm_worker()
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
||||
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
|
||||
tensor_parallel_size = 1
|
||||
data_parallel_size = 1
|
||||
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
|
||||
data_parallel_size = (
|
||||
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
||||
)
|
||||
host = cli_args.get("host") or cfg.vllm.host
|
||||
port = cli_args.get("port") or cfg.vllm.port
|
||||
gpu_memory_utilization = (
|
||||
@@ -64,10 +68,10 @@ def do_vllm_serve(
|
||||
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
|
||||
)
|
||||
|
||||
# pylint: disable=unexpected-keyword-arg
|
||||
vllm_script_args = AxolotlScriptArguments(
|
||||
model=model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
data_parallel_size=data_parallel_size,
|
||||
host=host,
|
||||
port=port,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
@@ -78,63 +82,3 @@ def do_vllm_serve(
|
||||
enable_reasoning=enable_reasoning,
|
||||
)
|
||||
vllm_serve_main(vllm_script_args)
|
||||
|
||||
|
||||
def patch_vllm_worker():
|
||||
from multiprocessing.connection import Connection
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
def llm_worker(
|
||||
script_args: AxolotlScriptArguments,
|
||||
data_parallel_rank: int,
|
||||
master_port: int,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
# Set required environment variables for DP to work with vLLM
|
||||
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
||||
|
||||
llm = LLM(
|
||||
model=script_args.model,
|
||||
revision=script_args.revision,
|
||||
tensor_parallel_size=script_args.tensor_parallel_size,
|
||||
gpu_memory_utilization=script_args.gpu_memory_utilization,
|
||||
enforce_eager=script_args.enforce_eager,
|
||||
dtype=script_args.dtype,
|
||||
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
|
||||
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
||||
# This is particularly useful here because we generate completions from the same prompts.
|
||||
enable_prefix_caching=script_args.enable_prefix_caching,
|
||||
kv_cache_dtype=script_args.kv_cache_dtype,
|
||||
max_model_len=script_args.max_model_len,
|
||||
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
|
||||
enable_reasoning=script_args.enable_reasoning,
|
||||
reasoning_parser=script_args.reasoning_parser,
|
||||
)
|
||||
|
||||
# Send ready signal to parent process
|
||||
connection.send({"status": "ready"})
|
||||
|
||||
while True:
|
||||
# Wait for commands from the parent process
|
||||
try:
|
||||
command = connection.recv()
|
||||
except KeyboardInterrupt:
|
||||
llm.collective_rpc(method="close_communicator")
|
||||
break
|
||||
|
||||
# Handle commands
|
||||
if command["type"] in ["call", "fire_and_forget"]:
|
||||
method_name = command["method"]
|
||||
args, kwargs = command.get("args", ()), command.get("kwargs", {})
|
||||
method = getattr(llm, method_name)
|
||||
result = method(*args, **kwargs)
|
||||
if command["type"] == "call":
|
||||
connection.send(result)
|
||||
elif command["type"] == "shutdown":
|
||||
break
|
||||
|
||||
trl.scripts.vllm_serve.llm_worker = llm_worker
|
||||
|
||||
@@ -13,4 +13,6 @@ MOE_ARCH_BLOCK = {
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"""
|
||||
Various shared constants
|
||||
"""
|
||||
"""Various shared constants"""
|
||||
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
|
||||
@@ -3,16 +3,14 @@
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401
|
||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
@@ -31,16 +29,7 @@ class TrainDatasetMeta:
|
||||
|
||||
|
||||
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
||||
"""
|
||||
Randomly sample `num_samples` samples from `dataset`.
|
||||
|
||||
Args:
|
||||
dataset: Dataset.
|
||||
num_samples: Number of samples to return.
|
||||
|
||||
Returns:
|
||||
Random sample (with replacement) of examples in `dataset`.
|
||||
"""
|
||||
"""Randomly sample `num_samples` samples with replacement from `dataset`."""
|
||||
return dataset.select(
|
||||
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
|
||||
)
|
||||
@@ -53,55 +42,50 @@ def load_datasets(
|
||||
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
|
||||
debug: bool = False,
|
||||
) -> TrainDatasetMeta:
|
||||
"""
|
||||
Loads one or more training or evaluation datasets, calling
|
||||
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
|
||||
"""Loads one or more training or evaluation datasets, calling
|
||||
`axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Command-specific CLI arguments.
|
||||
debug: Whether to print out tokenization of sample
|
||||
debug: Whether to print out tokenization of sample. This is duplicated in
|
||||
`cfg` and `cli_args`, but is kept due to use in our Colab notebooks.
|
||||
|
||||
Returns:
|
||||
Dataclass with fields for training and evaluation datasets and the computed
|
||||
`total_num_steps`.
|
||||
`total_num_steps`.
|
||||
"""
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||
preprocess_iterable = (
|
||||
cli_args
|
||||
and hasattr(cli_args, "iterable")
|
||||
and cli_args.iterable is not None
|
||||
and cli_args.iterable
|
||||
)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
|
||||
cfg,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
|
||||
if ( # pylint: disable=too-many-boolean-expressions
|
||||
cli_args
|
||||
and (
|
||||
cli_args.debug
|
||||
or cfg.debug
|
||||
or cli_args.debug_text_only
|
||||
or int(cli_args.debug_num_examples) > 0
|
||||
)
|
||||
) or debug:
|
||||
if (
|
||||
cfg.debug
|
||||
or getattr(cli_args, "debug", False)
|
||||
or getattr(cli_args, "debug_text_only", False)
|
||||
or getattr(cli_args, "debug_num_examples", 0) > 0
|
||||
or debug
|
||||
):
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||
text_only = cli_args.debug_text_only if cli_args else False
|
||||
train_samples = sample_dataset(train_dataset, num_examples)
|
||||
check_dataset_labels(
|
||||
train_samples,
|
||||
tokenizer,
|
||||
num_examples=num_examples,
|
||||
text_only=text_only,
|
||||
)
|
||||
try:
|
||||
train_samples = sample_dataset(train_dataset, num_examples)
|
||||
check_dataset_labels(
|
||||
train_samples,
|
||||
tokenizer,
|
||||
num_examples=num_examples,
|
||||
text_only=text_only,
|
||||
)
|
||||
except AttributeError:
|
||||
# can't sample iterable datasets
|
||||
pass
|
||||
|
||||
LOG.info("printing prompters...")
|
||||
for prompter in prompters:
|
||||
@@ -116,13 +100,10 @@ def load_datasets(
|
||||
|
||||
@send_errors
|
||||
def load_preference_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
||||
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
|
||||
) -> TrainDatasetMeta:
|
||||
"""
|
||||
Loads one or more training or evaluation datasets for RL training using paired
|
||||
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
|
||||
"""Loads one or more training or evaluation datasets for RL training using paired
|
||||
preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
|
||||
Optionally, logs out debug information.
|
||||
|
||||
Args:
|
||||
@@ -133,23 +114,28 @@ def load_preference_datasets(
|
||||
Dataclass with fields for training and evaluation datasets and the computed
|
||||
`total_num_steps`.
|
||||
"""
|
||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||
total_num_steps: Optional[int] = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
if cfg.rl is RLType.GRPO:
|
||||
total_num_steps = None
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
total_num_steps: int | None = None
|
||||
if cfg.rl is not RLType.GRPO:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
|
||||
if ((cli_args and cli_args.debug) or cfg.debug) and cfg.rl != RLType.ORPO:
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||
text_only = cli_args.debug_text_only if cli_args else False
|
||||
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
|
||||
train_samples = sample_dataset(train_dataset, num_examples)
|
||||
check_dataset_labels(
|
||||
train_samples,
|
||||
tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
dataset=train_samples,
|
||||
tokenizer=tokenizer,
|
||||
num_examples=num_examples,
|
||||
text_only=text_only,
|
||||
rl_mode=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -67,9 +67,7 @@ class JsonToJsonlConverter:
|
||||
self.json_parser = json_parser
|
||||
self.jsonl_serializer = jsonl_serializer
|
||||
|
||||
def convert(
|
||||
self, input_file_path, output_file_path
|
||||
): # pylint: disable=unused-argument
|
||||
def convert(self, input_file_path, output_file_path):
|
||||
content = self.file_reader.read(input_file_path)
|
||||
data = self.json_parser.parse(content)
|
||||
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
|
||||
|
||||
158
src/axolotl/core/attention/flex_block_mask.py
Normal file
158
src/axolotl/core/attention/flex_block_mask.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
monkeypatch for flex + packing
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
from transformers import Cache, PretrainedConfig
|
||||
from transformers.masking_utils import (
|
||||
ALL_MASK_ATTENTION_FUNCTIONS,
|
||||
_preprocess_mask_arguments,
|
||||
and_masks,
|
||||
causal_mask_function,
|
||||
or_masks,
|
||||
)
|
||||
from transformers.utils import is_torch_greater_or_equal
|
||||
|
||||
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
|
||||
|
||||
|
||||
def create_causal_mask(
|
||||
config: PretrainedConfig,
|
||||
input_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Optional[Cache],
|
||||
or_mask_function: Optional[Callable] = None,
|
||||
and_mask_function: Optional[Callable] = None,
|
||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
||||
"""
|
||||
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
|
||||
has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
|
||||
to what is needed in the `modeling_xxx.py` files).
|
||||
|
||||
Args:
|
||||
config (`PretrainedConfig`):
|
||||
The model config.
|
||||
input_embeds (`torch.Tensor`):
|
||||
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
|
||||
batch size, query length and dtype.
|
||||
attention_mask (`torch.Tensor`, optional):
|
||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
|
||||
It can also be an already prepared 4D mask, in which case it is returned as-is.
|
||||
cache_position (`torch.Tensor`):
|
||||
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
|
||||
past_key_values (`Cache`, optional):
|
||||
The past key values, if we use a cache.
|
||||
or_mask_function (`Callable`, optional):
|
||||
An optional mask function to combine with the causal mask function (by doing the union of both). This is
|
||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
||||
and_mask_function (`Callable`, optional):
|
||||
An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
|
||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
||||
"""
|
||||
# If we have an HybridCache structure, here we want to create the mask for the full layers
|
||||
if (
|
||||
past_key_values
|
||||
and hasattr(past_key_values, "is_sliding")
|
||||
and False in past_key_values.is_sliding
|
||||
):
|
||||
layer_idx = past_key_values.is_sliding.index(False)
|
||||
else:
|
||||
layer_idx = 0
|
||||
|
||||
original_attention_mask = (
|
||||
None
|
||||
if attention_mask is None
|
||||
else attention_mask.clone().to(cache_position.device)
|
||||
)
|
||||
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
|
||||
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
|
||||
)
|
||||
if early_exit:
|
||||
return attention_mask
|
||||
|
||||
batch_size, total_seq_len = cache_position.shape
|
||||
key_length = total_seq_len
|
||||
document_ids = torch.nn.functional.pad(
|
||||
original_attention_mask, value=0, pad=(0, key_length)
|
||||
)
|
||||
|
||||
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
|
||||
if attention_mask is not None:
|
||||
|
||||
def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
"""
|
||||
Defines the logic of a block causal mask by combining both a standard causal mask
|
||||
and a block diagonal document mask.
|
||||
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
||||
for an illustration.
|
||||
"""
|
||||
causal_mask_ = q_idx >= kv_idx # not valid when decoding
|
||||
document_mask = (
|
||||
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
|
||||
)
|
||||
final_mask = causal_mask_ & document_mask
|
||||
return final_mask
|
||||
|
||||
mask_factory_function = causal_doc_mask_mod
|
||||
else:
|
||||
mask_factory_function = causal_mask_function
|
||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
||||
|
||||
# Do not allow skip if we are compiling (this is to match BC)
|
||||
allow_is_causal_skip = (
|
||||
not past_key_values.is_compileable if past_key_values is not None else True
|
||||
)
|
||||
|
||||
# Allow slight deviations from causal mask
|
||||
if or_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError(
|
||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
||||
)
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError(
|
||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
||||
)
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
|
||||
# We now create the mask
|
||||
causal_mask = mask_interface(
|
||||
batch_size=batch_size,
|
||||
cache_position=cache_position,
|
||||
kv_length=kv_length,
|
||||
kv_offset=kv_offset,
|
||||
mask_function=mask_factory_function,
|
||||
attention_mask=attention_mask,
|
||||
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
def patch_create_causal_mask(model_type):
|
||||
import transformers.masking_utils
|
||||
|
||||
transformers.masking_utils.create_causal_mask = create_causal_mask
|
||||
|
||||
if model_type:
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
module = __import__(module_path)
|
||||
module.create_causal_mask = create_causal_mask
|
||||
del sys.modules[module_path]
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(
|
||||
f"Could not import attention class for model_type: {model_type}. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
@@ -24,10 +24,8 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||
@@ -36,16 +34,17 @@ from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
GPUStatsCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveModelOnFirstStepCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.distributed import build_parallelism_config
|
||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
with suppress(ImportError):
|
||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||
import torch._dynamo
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
@@ -114,13 +113,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||
)
|
||||
|
||||
if self.cfg.profiler_steps:
|
||||
callbacks.append(
|
||||
PytorchProfilerCallback(
|
||||
steps_to_profile=self.cfg.profiler_steps,
|
||||
)
|
||||
)
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
|
||||
@@ -144,8 +136,16 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.save_first_step:
|
||||
callbacks.append(SaveModelOnFirstStepCallback())
|
||||
|
||||
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
||||
if self.cfg.profiler_steps:
|
||||
callbacks.append(
|
||||
PytorchProfilerCallback(
|
||||
steps_to_profile=self.cfg.profiler_steps,
|
||||
profiler_steps_start=self.cfg.profiler_steps_start,
|
||||
)
|
||||
)
|
||||
|
||||
telemetry_manager = TelemetryManager.get_instance()
|
||||
if telemetry_manager.enabled:
|
||||
@@ -225,7 +225,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.bf16 == "full":
|
||||
training_args_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
|
||||
bf16 = self.cfg.bf16 or self.cfg.bfloat16
|
||||
bf16 = bf16 if bf16 is not None else False
|
||||
training_args_kwargs["bf16"] = bf16
|
||||
|
||||
def _configure_scheduler(self, training_args_kwargs: dict):
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
|
||||
@@ -262,33 +264,30 @@ class TrainerBuilderBase(abc.ABC):
|
||||
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
||||
|
||||
if self.cfg.optimizer == "muon":
|
||||
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||
from axolotl.contribs.mit.muon import (
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "dion":
|
||||
from axolotl.contribs.mit.dion import (
|
||||
DionOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = DionOptimizerFactory
|
||||
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
|
||||
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
_, device_mesh = build_parallelism_config(self.cfg)
|
||||
if device_mesh is not None:
|
||||
optimizer_kwargs["device_mesh"] = device_mesh
|
||||
elif self.cfg.optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
optimizer_kwargs["foreach"] = False
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||
# TODO remove 20250401
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
optimizer_cls = AdamW4bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
LOG.warning(
|
||||
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||
)
|
||||
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||
|
||||
optimizer_cls = AdamW8bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||
|
||||
@@ -386,14 +385,16 @@ class TrainerBuilderBase(abc.ABC):
|
||||
)
|
||||
|
||||
# eval_strategy and eval_steps
|
||||
if not self.eval_dataset or self.cfg.val_set_size == 0:
|
||||
# do not eval if no eval_dataset or val_set_size=0
|
||||
if not self.eval_dataset and self.cfg.val_set_size == 0:
|
||||
# do not eval if no eval_dataset and val_set_size=0
|
||||
training_args_kwargs["eval_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_args_kwargs["eval_strategy"] = "steps"
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
training_args_kwargs["eval_on_start"] = True
|
||||
elif self.cfg.eval_strategy:
|
||||
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||
training_args_kwargs["eval_on_start"] = True
|
||||
|
||||
def _configure_reporting(self, training_args_kwargs: dict):
|
||||
report_to = []
|
||||
@@ -417,9 +418,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
def _configure_torch_compile(self, training_args_kwargs: dict):
|
||||
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
|
||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||
True
|
||||
)
|
||||
torch._dynamo.config.suppress_errors = True
|
||||
torch._dynamo.config.accumulated_cache_size_limit = 256
|
||||
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||
if self.cfg.torch_compile_backend:
|
||||
training_args_kwargs["torch_compile_backend"] = (
|
||||
@@ -428,8 +428,20 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.torch_compile_mode:
|
||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||
|
||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||
if self.cfg.accelerator_config:
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
**self.cfg.accelerator_config
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.gradient_checkpointing:
|
||||
if self.cfg.activation_offloading is True:
|
||||
# don't use the HF gradient checkpointing, manually wrap
|
||||
training_args_kwargs["gradient_checkpointing"] = False
|
||||
training_args_kwargs["activation_offloading"] = True
|
||||
elif self.cfg.gradient_checkpointing is not None:
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
)
|
||||
@@ -482,17 +494,30 @@ class TrainerBuilderBase(abc.ABC):
|
||||
"include_tokens_per_second",
|
||||
"weight_decay",
|
||||
"seed",
|
||||
"dion_momentum",
|
||||
"dion_rank_fraction",
|
||||
"dion_rank_multiple_of",
|
||||
"dataset_num_proc",
|
||||
]:
|
||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||
|
||||
arg_map = {
|
||||
"dion_learning_rate": "dion_lr",
|
||||
}
|
||||
for kwarg, cfg_arg in arg_map.items():
|
||||
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
|
||||
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
|
||||
|
||||
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||
training_args_kwargs["average_tokens_across_devices"] = False
|
||||
|
||||
if self.cfg.eval_batch_size:
|
||||
training_args_kwargs["per_device_eval_batch_size"] = (
|
||||
self.cfg.eval_batch_size
|
||||
)
|
||||
|
||||
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
|
||||
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
|
||||
@@ -500,10 +525,15 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.reward_model or self.cfg.rl:
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
if self.cfg.fsdp_config or self.cfg.fsdp:
|
||||
training_args_kwargs["fsdp_config"] = self.cfg.fsdp_config
|
||||
training_args_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp else True
|
||||
|
||||
self._configure_reporting(training_args_kwargs)
|
||||
self._configure_hub_parameters(training_args_kwargs)
|
||||
self._configure_scheduler(training_args_kwargs)
|
||||
self._configure_optimizer(training_args_kwargs, trainer_kwargs)
|
||||
self._configure_torch_compile(training_args_kwargs)
|
||||
self._configure_accelerator_config(training_args_kwargs)
|
||||
|
||||
return training_args_kwargs, trainer_kwargs
|
||||
|
||||
@@ -10,6 +10,7 @@ import transformers
|
||||
from transformers import (
|
||||
DataCollatorWithFlattening,
|
||||
EarlyStoppingCallback,
|
||||
Trainer,
|
||||
)
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
@@ -19,12 +20,6 @@ from axolotl.core.trainers import (
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
AxolotlTrainer,
|
||||
ReLoRATrainer,
|
||||
)
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlPRMConfig,
|
||||
AxolotlRewardConfig,
|
||||
AxolotlTrainingArguments,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
@@ -32,9 +27,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback
|
||||
from axolotl.processing_strategies import get_processing_strategy
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
LossWatchDogCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
colab_inference_post_train_callback,
|
||||
@@ -42,6 +35,7 @@ from axolotl.utils.callbacks import (
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
@@ -50,6 +44,7 @@ from axolotl.utils.collators import (
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -63,17 +58,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
callbacks.append(EvalFirstStepCallback())
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
if self.cfg.relora:
|
||||
callbacks.append(ReLoRACallback(self.cfg))
|
||||
|
||||
if (
|
||||
hasattr(self.model, "use_bettertransformer")
|
||||
and self.model.use_bettertransformer is True
|
||||
):
|
||||
callbacks.append(SaveBetterTransformerModelCallback())
|
||||
|
||||
# TODO: check if can move to base class
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
@@ -81,6 +69,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.qat:
|
||||
callbacks.append(QATCallback(self.cfg.qat))
|
||||
|
||||
if self.cfg.include_tkps:
|
||||
callbacks.append(
|
||||
TokensPerSecondCallback(
|
||||
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
||||
)
|
||||
)
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -130,33 +124,44 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
"""
|
||||
Gets the trainer class for the given configuration.
|
||||
"""
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
if trainer_cls:
|
||||
return trainer_cls
|
||||
if self.cfg.relora_steps:
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
if self.cfg.reward_model:
|
||||
return AxolotlRewardTrainer
|
||||
if self.cfg.process_reward_model:
|
||||
return AxolotlPRMTrainer
|
||||
|
||||
if self.cfg.trainer_cls:
|
||||
# override the trainer cls
|
||||
try:
|
||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||
return trainer_cls
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||
) from e
|
||||
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlPRMConfig,
|
||||
AxolotlRewardConfig,
|
||||
AxolotlTrainingArguments,
|
||||
)
|
||||
|
||||
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||
total_num_steps
|
||||
)
|
||||
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = {
|
||||
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
|
||||
}
|
||||
|
||||
if self.cfg.adapter == "qlora":
|
||||
training_arguments_kwargs["qlora"] = True
|
||||
|
||||
@@ -243,14 +248,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
|
||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||
training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
|
||||
self.cfg.flash_attention
|
||||
or self.cfg.xformers_attention
|
||||
or self.cfg.flex_attention
|
||||
)
|
||||
training_arguments_kwargs["multipack_real_batches"] = (
|
||||
self.cfg.multipack_real_batches
|
||||
if self.cfg.multipack_real_batches is not None
|
||||
else not self.cfg.flash_attention
|
||||
else not (
|
||||
self.cfg.flash_attention
|
||||
or self.cfg.flex_attention
|
||||
or self.cfg.xformers_attention
|
||||
)
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = bool(
|
||||
self.cfg.eval_sample_packing
|
||||
)
|
||||
if self.cfg.sample_packing_sequentially is not None:
|
||||
training_arguments_kwargs["sample_packing_sequentially"] = (
|
||||
self.cfg.sample_packing_sequentially
|
||||
)
|
||||
if self.cfg.sample_packing_bin_size is not None:
|
||||
training_arguments_kwargs["sample_packing_bin_size"] = (
|
||||
self.cfg.sample_packing_bin_size
|
||||
@@ -264,20 +282,25 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.sample_packing_eff_est
|
||||
)
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||
training_arguments_kwargs["relora_warmup_steps"] = (
|
||||
self.cfg.relora_warmup_steps
|
||||
)
|
||||
if self.cfg.relora_anneal_steps:
|
||||
training_arguments_kwargs["relora_anneal_steps"] = (
|
||||
self.cfg.relora_anneal_steps
|
||||
)
|
||||
if self.cfg.relora and self.cfg.jagged_restart_steps:
|
||||
if self.cfg.relora_prune_ratio:
|
||||
training_arguments_kwargs["relora_prune_ratio"] = (
|
||||
self.cfg.relora_prune_ratio
|
||||
)
|
||||
|
||||
if self.cfg.jagged_restart_steps:
|
||||
training_arguments_kwargs["jagged_restart_steps"] = (
|
||||
self.cfg.jagged_restart_steps
|
||||
)
|
||||
if self.cfg.jagged_restart_warmup_steps:
|
||||
training_arguments_kwargs["jagged_restart_warmup_steps"] = (
|
||||
self.cfg.jagged_restart_warmup_steps
|
||||
)
|
||||
if self.cfg.jagged_restart_anneal_steps:
|
||||
training_arguments_kwargs["jagged_restart_anneal_steps"] = (
|
||||
self.cfg.jagged_restart_anneal_steps
|
||||
)
|
||||
|
||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
|
||||
training_arguments_kwargs["lisa_step_interval"] = (
|
||||
@@ -303,48 +326,37 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.neftune_noise_alpha
|
||||
)
|
||||
|
||||
if self.cfg.accelerator_config:
|
||||
training_arguments_kwargs["accelerator_config"] = (
|
||||
self.cfg.accelerator_config
|
||||
)
|
||||
|
||||
if self.cfg.image_size:
|
||||
training_arguments_kwargs["image_size"] = self.cfg.image_size
|
||||
if self.cfg.image_resize_algorithm:
|
||||
training_arguments_kwargs["image_resize_algorithm"] = (
|
||||
self.cfg.image_resize_algorithm
|
||||
)
|
||||
if self.cfg.kd_ce_alpha is not None:
|
||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||
if self.cfg.kd_alpha is not None:
|
||||
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
||||
if self.cfg.kd_temperature is not None:
|
||||
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
||||
if self.cfg.kd_zscore_base_temp is not None:
|
||||
training_arguments_kwargs["kd_zscore_base_temp"] = (
|
||||
self.cfg.kd_zscore_base_temp
|
||||
)
|
||||
if self.cfg.kd_top_k_before_softmax is not None:
|
||||
training_arguments_kwargs["kd_top_k_before_softmax"] = (
|
||||
self.cfg.kd_top_k_before_softmax
|
||||
)
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_training_args = plugin_manager.get_training_args(self.cfg)
|
||||
if plugin_training_args:
|
||||
training_arguments_kwargs.update(plugin_training_args)
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
if self.cfg.center_rewards_coefficient is not None:
|
||||
training_arguments_kwargs["center_rewards_coefficient"] = (
|
||||
self.cfg.center_rewards_coefficient
|
||||
)
|
||||
elif self.cfg.process_reward_model:
|
||||
training_args_cls = AxolotlPRMConfig
|
||||
else:
|
||||
training_args_cls = AxolotlTrainingArguments
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
training_args = training_args_cls(
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
training_args = self.hook_post_create_training_args(training_args)
|
||||
|
||||
# unset run_name so wandb sets up experiment names
|
||||
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
|
||||
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
|
||||
None
|
||||
)
|
||||
training_args.run_name = None
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True, # True/"longest" is the default
|
||||
@@ -354,7 +366,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||
self.cfg.sequence_len / multiple
|
||||
)
|
||||
else:
|
||||
elif self.cfg.pad_to_sequence_len is None:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||
@@ -376,12 +388,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
**data_collator_kwargs,
|
||||
)
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "processing_class" in sig.parameters:
|
||||
if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
|
||||
trainer_kwargs["processing_class"] = self.tokenizer
|
||||
elif "tokenizer" in sig.parameters:
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
if (
|
||||
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
|
||||
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
|
||||
and self.cfg.datasets is not None
|
||||
):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
@@ -397,6 +410,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
**trainer_kwargs,
|
||||
)
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
# if the trainer has the `axolotl_cfg` property, set it
|
||||
if hasattr(trainer, "axolotl_cfg"):
|
||||
trainer.axolotl_cfg = self.cfg
|
||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
@@ -408,7 +424,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return trainer
|
||||
|
||||
def build_collator(
|
||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||
self,
|
||||
training_args, # type: "AxolotlTrainingArguments" # type: ignore
|
||||
is_eval=False,
|
||||
**kwargs,
|
||||
):
|
||||
if training_args.pretraining:
|
||||
if (
|
||||
@@ -416,7 +435,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
or self.cfg.micro_batch_size > 1
|
||||
):
|
||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||
return None
|
||||
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
|
||||
return None
|
||||
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||
@@ -437,7 +457,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
]
|
||||
]
|
||||
collator_args = [self.tokenizer]
|
||||
if self.cfg.reward_model:
|
||||
|
||||
collator_cls_and_kwargs = None
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
|
||||
self.cfg, is_eval=is_eval
|
||||
)
|
||||
|
||||
if collator_cls_and_kwargs:
|
||||
collator = collator_cls_and_kwargs[0]
|
||||
if kwargs and isinstance(kwargs, dict):
|
||||
kwargs.update(collator_cls_and_kwargs[1])
|
||||
elif self.cfg.reward_model:
|
||||
collator = RewardDataCollatorWithPadding
|
||||
elif use_batch_sampler_collator:
|
||||
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
||||
@@ -468,16 +500,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
collator_args.pop(0)
|
||||
kwargs.pop("pad_to_multiple_of", None)
|
||||
kwargs.pop("padding", None)
|
||||
elif self.cfg.kd_trainer:
|
||||
from axolotl.integrations.kd.collator import (
|
||||
DataCollatorForKD,
|
||||
KDBatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
|
||||
if self.cfg.sample_packing:
|
||||
collator = KDBatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
collator = DataCollatorForKD
|
||||
else:
|
||||
collator = DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@@ -12,13 +12,10 @@ from axolotl.core.trainers import (
|
||||
from axolotl.core.trainers.dpo import DPOStrategy
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlCPOConfig,
|
||||
AxolotlKTOConfig,
|
||||
AxolotlORPOConfig,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import ensure_dtype
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
@@ -31,6 +28,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
|
||||
if self.cfg.qat:
|
||||
callbacks.append(QATCallback(self.cfg.qat))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -54,7 +54,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
@@ -73,12 +73,28 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
if self.cfg.trainer_cls:
|
||||
# override the trainer cls
|
||||
try:
|
||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||
) from e
|
||||
|
||||
return trainer_cls, trainer_cls_args
|
||||
|
||||
def _build_training_arguments(self, total_num_steps):
|
||||
"""
|
||||
Returns training_args and trainer_kwargs
|
||||
"""
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlCPOConfig,
|
||||
AxolotlKTOConfig,
|
||||
AxolotlORPOConfig,
|
||||
)
|
||||
|
||||
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||
total_num_steps=total_num_steps
|
||||
)
|
||||
@@ -90,10 +106,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
training_args_kwargs["remove_unused_columns"] = False
|
||||
|
||||
# only rlhf
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
if self.cfg.trl and self.cfg.trl.beta is not None:
|
||||
training_args_kwargs["beta"] = self.cfg.trl.beta
|
||||
elif self.cfg.rl_beta is not None:
|
||||
@@ -108,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.use_wandb:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
else:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl is RLType.SIMPO:
|
||||
@@ -117,10 +134,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
# Handle when max_prompt_length == max_length from defaults
|
||||
# CPOTrainer requires strictly less than
|
||||
if (
|
||||
training_args_kwargs["max_prompt_length"]
|
||||
== training_args_kwargs["max_length"]
|
||||
):
|
||||
training_args_kwargs["max_prompt_length"] -= 1
|
||||
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
@@ -132,9 +155,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kto_undesirable_weight or 1.0
|
||||
)
|
||||
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl is RLType.GRPO:
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
@@ -142,22 +162,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
if self.cfg.rl is RLType.IPO:
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
|
||||
# Not compatible with IPO
|
||||
if self.cfg.rl is RLType.DPO and self.cfg.dpo_label_smoothing:
|
||||
training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs["use_logits_to_keep"] = (
|
||||
self.cfg.dpo_use_logits_to_keep
|
||||
)
|
||||
training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
@@ -165,16 +170,20 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if blocklist_key in training_args_kwargs:
|
||||
del training_args_kwargs[blocklist_key]
|
||||
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_training_args = plugin_manager.get_training_args(self.cfg)
|
||||
if plugin_training_args:
|
||||
training_args_kwargs.update(plugin_training_args)
|
||||
|
||||
training_args = training_args_cls(
|
||||
logging_first_step=True,
|
||||
**training_args_kwargs,
|
||||
)
|
||||
|
||||
# unset run_name so wandb sets up experiment names
|
||||
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
|
||||
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
|
||||
None
|
||||
)
|
||||
training_args.run_name = None
|
||||
|
||||
return training_args, trainer_kwargs
|
||||
|
||||
@@ -216,7 +225,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
callbacks=self.get_callbacks(),
|
||||
**trainer_kwargs,
|
||||
)
|
||||
if self.cfg.fsdp:
|
||||
if self.cfg.fsdp_config or self.cfg.fsdp:
|
||||
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
|
||||
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
|
||||
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
|
||||
@@ -226,21 +235,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
HF Factory class for PPO Trainer
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
# TODO: build PPOConfig
|
||||
raise NotImplementedError("PPO trainer builder is not implemented yet.")
|
||||
|
||||
@@ -10,7 +10,7 @@ from .shared import wrap_tools
|
||||
|
||||
def format_message(
|
||||
message: Messages,
|
||||
message_index: Optional[int] = None, # pylint: disable=unused-argument
|
||||
message_index: Optional[int] = None,
|
||||
) -> Messages:
|
||||
if message.is_chat_formatted:
|
||||
return message
|
||||
|
||||
@@ -15,11 +15,11 @@ class MessageRoles(str, Enum):
|
||||
Message roles for the system, user, assistant, and tools
|
||||
"""
|
||||
|
||||
system = "system" # pylint: disable=invalid-name
|
||||
user = "user" # pylint: disable=invalid-name
|
||||
assistant = "assistant" # pylint: disable=invalid-name
|
||||
tool = "tool" # pylint: disable=invalid-name
|
||||
ipython = ( # pylint: disable=invalid-name
|
||||
system = "system"
|
||||
user = "user"
|
||||
assistant = "assistant"
|
||||
tool = "tool"
|
||||
ipython = (
|
||||
# for responses from builtin tools
|
||||
"ipython"
|
||||
)
|
||||
@@ -30,12 +30,12 @@ class MessageContentTypes(str, Enum):
|
||||
Message content types for text, image, audio, tool calls, and tool responses
|
||||
"""
|
||||
|
||||
special_token = "special_token" # pylint: disable=invalid-name # nosec B105
|
||||
text = "text" # pylint: disable=invalid-name
|
||||
image = "image" # pylint: disable=invalid-name
|
||||
audio = "audio" # pylint: disable=invalid-name
|
||||
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
|
||||
tool_response = "tool_response" # pylint: disable=invalid-name
|
||||
special_token = "special_token" # nosec B105
|
||||
text = "text"
|
||||
image = "image"
|
||||
audio = "audio"
|
||||
tool_call = "tool_call"
|
||||
tool_response = "tool_response"
|
||||
|
||||
|
||||
class SpecialToken(str, Enum):
|
||||
@@ -43,8 +43,8 @@ class SpecialToken(str, Enum):
|
||||
Special tokens for beginning of string and end of string
|
||||
"""
|
||||
|
||||
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
|
||||
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
|
||||
bos_token = "bos_token" # nosec B105
|
||||
eos_token = "eos_token" # nosec B105
|
||||
|
||||
|
||||
class ToolCallFunction(BaseModel):
|
||||
@@ -73,7 +73,7 @@ class ToolCallContents(BaseModel):
|
||||
|
||||
name: str
|
||||
arguments: dict[str, Union[str, int]]
|
||||
id: Optional[str] = None # pylint: disable=invalid-name
|
||||
id: Optional[str] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
data = {"name": self.name, "arguments": self.arguments}
|
||||
@@ -89,7 +89,7 @@ class ToolResponseContents(BaseModel):
|
||||
|
||||
name: str
|
||||
content: Union[str, dict[str, Union[str, int, float]]]
|
||||
id: Optional[str] = None # pylint: disable=invalid-name
|
||||
id: Optional[str] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
data = {"name": self.name, "content": self.content}
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
chat dataset module
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from datasets import Dataset
|
||||
@@ -41,14 +40,10 @@ class TokenizedChatDataset(Dataset):
|
||||
)
|
||||
return ex.tokenized(model_transform)
|
||||
|
||||
process_or_cpu_count: int = (
|
||||
process_count or os.cpu_count() # type: ignore[assignment]
|
||||
)
|
||||
num_proc = min(32, process_or_cpu_count)
|
||||
features = data.features.keys()
|
||||
tokenized_data = data.map(
|
||||
map_fn,
|
||||
num_proc=num_proc,
|
||||
num_proc=process_count,
|
||||
keep_in_memory=keep_in_memory,
|
||||
remove_columns=features,
|
||||
desc="Tokenizing Chats",
|
||||
|
||||
@@ -1,23 +1,17 @@
|
||||
"""
|
||||
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
|
||||
This module contains a function that builds a transform that takes a row from the
|
||||
dataset and converts it to a Chat.
|
||||
"""
|
||||
|
||||
from typing import Any, Mapping, Union
|
||||
from typing import Any, Mapping
|
||||
|
||||
|
||||
def chat_message_transform_builder( # pylint: disable=dangerous-default-value
|
||||
def chat_message_transform_builder(
|
||||
train_on_inputs=False,
|
||||
conversations_field: str = "conversations",
|
||||
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
|
||||
message_field_content: Union[str, list[str]] = [
|
||||
"value",
|
||||
"text",
|
||||
"content",
|
||||
], # commonly "content"
|
||||
message_field_training: Union[str, list[str]] = [
|
||||
"train",
|
||||
"weight",
|
||||
], # commonly "weight"
|
||||
conversations_field: str = "messages",
|
||||
message_field_role: str | list[str] | None = None, # commonly "role"
|
||||
message_field_content: str | list[str] | None = None, # commonly "content"
|
||||
message_field_training: str | list[str] | None = None, # commonly "weight"
|
||||
):
|
||||
"""Builds a transform that takes a row from the dataset and converts it to a Chat
|
||||
|
||||
@@ -26,19 +20,25 @@ def chat_message_transform_builder( # pylint: disable=dangerous-default-value
|
||||
If True, the transform will train on the inputs. If False, the transform will train on the targets.
|
||||
Defaults to False.
|
||||
conversations_field (str, optional):
|
||||
The field name of the conversations. Defaults to "conversations".
|
||||
The field name of the conversations. Defaults to "messages".
|
||||
message_field_role (str | list[str], optional):
|
||||
The field name of the role. Defaults to "role".
|
||||
The field name of the role.
|
||||
message_field_content (str | list[str], optional):
|
||||
The field name of the message content. Defaults to "content".
|
||||
The field name of the message content.
|
||||
message_field_training (str | list[str], optional):
|
||||
The field name of the train/weight. Defaults to "weight".
|
||||
The field name of the train/weight.
|
||||
|
||||
Returns:
|
||||
Callable:
|
||||
A function that takes a list of conversations and returns a list of messages.
|
||||
"""
|
||||
|
||||
if message_field_training is None:
|
||||
message_field_training = ["train", "weight"]
|
||||
if message_field_content is None:
|
||||
message_field_content = ["value", "text", "content"]
|
||||
if message_field_role is None:
|
||||
message_field_role = ["role", "from"]
|
||||
message_field_role = (
|
||||
[message_field_role]
|
||||
if isinstance(message_field_role, str)
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
"""Init for axolotl.core.trainers"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .relora import ReLoRATrainer
|
||||
from .trl import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
TRLPPOTrainer,
|
||||
)
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
"""Module for customized trainers"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial, wraps
|
||||
from typing import Callable, Literal, Optional
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
import datasets
|
||||
import safetensors
|
||||
import torch
|
||||
from accelerate.state import AcceleratorState
|
||||
from datasets import Dataset
|
||||
from peft import PeftModel
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
DataLoader,
|
||||
@@ -19,13 +20,19 @@ from torch.utils.data import (
|
||||
Sampler,
|
||||
SequentialSampler,
|
||||
)
|
||||
from transformers import Trainer
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from transformers import PreTrainedModel, Trainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.mixins import (
|
||||
ActivationOffloadingMixin,
|
||||
CheckpointSaveMixin,
|
||||
DistributedParallelMixin,
|
||||
OptimizerMixin,
|
||||
PackingMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
)
|
||||
@@ -33,17 +40,46 @@ from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
)
|
||||
from axolotl.utils import get_not_null
|
||||
from axolotl.utils.bench import get_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
REDUCTION_FNS = {
|
||||
"mean": torch.mean,
|
||||
"min": torch.min,
|
||||
"max": torch.max,
|
||||
"sum": torch.sum,
|
||||
}
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
|
||||
class AxolotlTrainer(
|
||||
PackingMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
CheckpointSaveMixin,
|
||||
ActivationOffloadingMixin,
|
||||
DistributedParallelMixin,
|
||||
Trainer,
|
||||
):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
tag_names = ["axolotl"]
|
||||
_axolotl_cfg: DictDefault | None = None
|
||||
|
||||
@property
|
||||
def axolotl_cfg(self):
|
||||
return self._axolotl_cfg
|
||||
|
||||
@axolotl_cfg.setter
|
||||
def axolotl_cfg(self, cfg):
|
||||
self._axolotl_cfg = cfg
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -59,24 +95,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
self._signature_columns = None # workaround for pylint
|
||||
|
||||
super().__init__(*_args, **kwargs)
|
||||
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
self._stored_metrics = defaultdict(
|
||||
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
|
||||
)
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
model = torch.compile(
|
||||
model,
|
||||
backend=self.args.torch_compile_backend,
|
||||
mode=self.args.torch_compile_mode,
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def _create_multipack_sampler(
|
||||
self, base_sampler: Sampler, dataset: Dataset
|
||||
) -> MultipackBatchSampler:
|
||||
@@ -101,7 +126,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
)
|
||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||
|
||||
return MultipackBatchSampler(
|
||||
sampler = MultipackBatchSampler(
|
||||
base_sampler,
|
||||
lengths=get_dataset_lengths(dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
@@ -111,11 +136,16 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
bin_size=self.args.sample_packing_bin_size,
|
||||
sequential=self.args.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
num_processes=self.args.dataset_num_proc,
|
||||
mp_start_method=self.args.sample_packing_mp_start_method or "fork",
|
||||
)
|
||||
|
||||
len(sampler)
|
||||
return sampler
|
||||
|
||||
def _get_train_sampler(
|
||||
self, train_dataset: Optional[Dataset] = None
|
||||
) -> Optional[Sampler]:
|
||||
self, train_dataset: Dataset | None = None
|
||||
) -> Sampler | None:
|
||||
"""
|
||||
Helper method to get the sampler for training. Handles cases for sample packing
|
||||
and curriculum sampling (sequential).
|
||||
@@ -124,16 +154,22 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
If the dataset is non-empty, a sampler is returned, the type of which
|
||||
depends on the passed training args.
|
||||
"""
|
||||
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24
|
||||
if train_dataset is None:
|
||||
train_dataset = self.train_dataset
|
||||
if train_dataset is None or not has_length(train_dataset):
|
||||
return None
|
||||
|
||||
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||
|
||||
# Determine the base sampler first
|
||||
if self.args.curriculum_sampling:
|
||||
base_sampler = SequentialSampler(self.train_dataset)
|
||||
base_sampler = SequentialSampler(train_dataset)
|
||||
elif use_sample_packing:
|
||||
base_sampler = RandomSampler(self.train_dataset)
|
||||
base_sampler = RandomSampler(train_dataset)
|
||||
else:
|
||||
# Default to parent class implementation for standard random sampling
|
||||
return super()._get_train_sampler()
|
||||
return super()._get_train_sampler(train_dataset)
|
||||
|
||||
# Apply multipack wrapper if needed
|
||||
if use_sample_packing:
|
||||
@@ -152,6 +188,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
If the dataset is non-empty, a sampler is returned, the type of which
|
||||
depends on the passed training args.
|
||||
"""
|
||||
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24
|
||||
if eval_dataset is None or not has_length(eval_dataset):
|
||||
return None
|
||||
|
||||
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
||||
use_multipack = (
|
||||
self.args.sample_packing and self.args.eval_sample_packing is not False
|
||||
@@ -187,6 +227,14 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
|
||||
if dataset.column_names and "length" in dataset.column_names:
|
||||
dataset = dataset.remove_columns(["length"])
|
||||
if (
|
||||
dataset.column_names
|
||||
and "position_ids" in dataset.column_names
|
||||
and "attention_mask" in dataset.column_names
|
||||
and self.args.sample_packing
|
||||
and self.args.sample_packing_drop_attention_mask
|
||||
):
|
||||
dataset = dataset.remove_columns(["attention_mask"])
|
||||
|
||||
if isinstance(dataset, datasets.Dataset):
|
||||
if is_training:
|
||||
@@ -220,7 +268,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
}
|
||||
|
||||
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["drop_last"] = get_not_null(
|
||||
self.args.dataloader_drop_last, True
|
||||
)
|
||||
if sampler_fn is not None:
|
||||
sampler = sampler_fn(dataset)
|
||||
if isinstance(sampler, BatchSampler):
|
||||
@@ -251,9 +301,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
# fmt: off
|
||||
if dataloader_key is not None and self.args.dataloader_persistent_workers:
|
||||
if hasattr(self, "_eval_dataloaders"):
|
||||
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore # pylint: disable=access-member-before-definition
|
||||
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore
|
||||
else:
|
||||
self._eval_dataloaders = {dataloader_key: dataloader} # pylint: disable=attribute-defined-outside-init
|
||||
self._eval_dataloaders = {dataloader_key: dataloader}
|
||||
# fmt: on
|
||||
|
||||
return self.accelerator.prepare(dataloader)
|
||||
@@ -295,6 +345,17 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
|
||||
# track number of tokens for tokens per second calculation
|
||||
if self.args.include_tkps:
|
||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||
if hasattr(self.state, "num_tokens"):
|
||||
self.state.num_tokens = (
|
||||
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
|
||||
)
|
||||
else:
|
||||
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
|
||||
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(
|
||||
model,
|
||||
@@ -310,6 +371,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
@override
|
||||
def evaluate(self, *args, **kwargs):
|
||||
LOG.info("Running evaluation step...")
|
||||
return super().evaluate(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||
concatenated_batch = {}
|
||||
@@ -409,7 +475,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
num_items_in_batch=None,
|
||||
):
|
||||
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
|
||||
inputs,
|
||||
@@ -486,26 +552,32 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
|
||||
@wraps(Trainer.create_accelerator_and_postprocess)
|
||||
def create_accelerator_and_postprocess(self):
|
||||
res = super().create_accelerator_and_postprocess()
|
||||
# cleanup the PartialState states so Accelerate automatically configures everything from the env vars
|
||||
accelerator_config = self.args.accelerator_config.to_dict()
|
||||
use_configured_state = accelerator_config.get("use_configured_state", False)
|
||||
if not use_configured_state:
|
||||
AcceleratorState._reset_state(reset_partial_state=True)
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
if (
|
||||
"limit_all_gathers" in self.args.fsdp_config
|
||||
and self.args.fsdp_config["limit_all_gathers"]
|
||||
):
|
||||
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
|
||||
|
||||
return res
|
||||
super().create_accelerator_and_postprocess()
|
||||
|
||||
def additional_accelerator_args(
|
||||
self, fp8=None, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
|
||||
) -> dict[str, Any]:
|
||||
ret_kwargs = {}
|
||||
if fp8:
|
||||
from accelerate.utils import AORecipeKwargs
|
||||
from torchao.float8 import Float8LinearConfig
|
||||
|
||||
# By default, Float8LinearConfig is instantiated using the "tensorwise"
|
||||
# scaling strategy. See more details here:
|
||||
# https://github.com/pytorch/ao/tree/main/torchao/float8.
|
||||
config = Float8LinearConfig(
|
||||
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
|
||||
force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True,
|
||||
)
|
||||
|
||||
ret_kwargs["mixed_precision"] = "fp8"
|
||||
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
|
||||
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore
|
||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
|
||||
|
||||
return ret_kwargs
|
||||
@@ -520,18 +592,61 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
|
||||
for key, metric_data in self._stored_metrics[train_eval].items():
|
||||
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
|
||||
reduction_type = metric_data["reduction"]
|
||||
|
||||
fn = REDUCTION_FNS.get(reduction_type)
|
||||
if fn is None:
|
||||
raise NotImplementedError(
|
||||
"Metric reduction must be one of [mean, min, max, sum]"
|
||||
)
|
||||
logs[key] = round(fn(values).item(), 4)
|
||||
|
||||
if is_main_process():
|
||||
# Add memory usage
|
||||
try:
|
||||
active, allocated, reserved = get_gpu_memory_usage()
|
||||
logs["memory/max_active (GiB)"] = round(active, 2)
|
||||
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
|
||||
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
|
||||
except (ValueError, TypeError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
if self.args.include_tkps and train_eval == "train":
|
||||
# each rank will log its own tokens per second
|
||||
# for logging_steps > 1 we obtain a moving average of this metric
|
||||
logs["tokens_per_second_per_gpu"] = round(
|
||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||
)
|
||||
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
return super().log(logs, start_time)
|
||||
|
||||
def store_metrics(
|
||||
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
self,
|
||||
metrics: dict[str, float] | dict[str, tuple[int | float, str]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
reduction: Literal["mean", "min", "max", "sum"] = "mean",
|
||||
) -> None:
|
||||
"""
|
||||
Store metrics with specified reduction type.
|
||||
|
||||
Args:
|
||||
metrics: Dictionary of metric names to values, or metric names to (value,
|
||||
reduction_type) tuples.
|
||||
train_eval: Whether this is for training or evaluation.
|
||||
"""
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
if isinstance(value, tuple):
|
||||
value, _reduction = value # type: ignore[assignment]
|
||||
else:
|
||||
value, _reduction = value, reduction
|
||||
|
||||
self._stored_metrics[train_eval][key]["values"].append(value)
|
||||
self._stored_metrics[train_eval][key]["reduction"] = _reduction
|
||||
|
||||
def _save_checkpoint(self, model, trial, **kwargs):
|
||||
# make sure the checkpoint dir exists, since trainer is flakey
|
||||
@@ -540,3 +655,69 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||
# If we are executing this function, we are the process zero, so we don't check for that.
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||
supported_classes = (
|
||||
(PreTrainedModel,)
|
||||
if not is_peft_available()
|
||||
else (PreTrainedModel, PeftModel)
|
||||
)
|
||||
# Save a trained model and configuration using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
if not isinstance(self.model, supported_classes):
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
if isinstance(
|
||||
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
|
||||
supported_classes,
|
||||
):
|
||||
self.accelerator.unwrap_model(
|
||||
self.model, keep_torch_compile=False
|
||||
).save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
|
||||
)
|
||||
if self.args.save_safetensors:
|
||||
safetensors.torch.save_file(
|
||||
state_dict,
|
||||
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
else:
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
self.model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
is_main_process=self.accelerator.is_main_process,
|
||||
)
|
||||
|
||||
if self.processing_class is not None:
|
||||
self.processing_class.save_pretrained(output_dir)
|
||||
elif (
|
||||
self.data_collator is not None
|
||||
and hasattr(self.data_collator, "tokenizer")
|
||||
and self.data_collator.tokenizer is not None
|
||||
):
|
||||
LOG.info(
|
||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||
)
|
||||
save_jinja_files = True
|
||||
if self.axolotl_cfg:
|
||||
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
|
||||
self.data_collator.tokenizer.save_pretrained(
|
||||
output_dir, save_jinja_files=save_jinja_files
|
||||
)
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
|
||||
@@ -22,10 +22,18 @@ class DPOStrategy:
|
||||
training_args_kwargs = {}
|
||||
if cfg.rl is RLType.IPO:
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
# Label smoothing is not compatible with IPO
|
||||
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||
if cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||
if cfg.dpo_padding_free is not None:
|
||||
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
|
||||
if cfg.dpo_norm_loss is not None:
|
||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||
if cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||
return training_args_kwargs
|
||||
|
||||
@@ -14,3 +14,5 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
|
||||
dpo_norm_loss: bool | None = False
|
||||
|
||||
@@ -8,7 +8,11 @@ import torch
|
||||
from torch import nn
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import (
|
||||
DistributedParallelMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
)
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
@@ -17,7 +21,12 @@ from axolotl.core.trainers.utils import (
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DPOTrainer,
|
||||
DistributedParallelMixin,
|
||||
):
|
||||
"""Extend the base DPOTrainer for axolotl helpers."""
|
||||
|
||||
@@ -83,3 +92,20 @@ class AxolotlDPOTrainer(
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
|
||||
def concatenated_forward(
|
||||
self,
|
||||
model: nn.Module,
|
||||
batch: dict[str, Union[list, torch.LongTensor]],
|
||||
is_ref_model: bool = False,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if self.args.dpo_norm_loss:
|
||||
# fmt: off
|
||||
loss_type: str = self.loss_type # type: ignore[has-type]
|
||||
# fmt: on
|
||||
# concatenated_forward handles avg token logprob for ipo case already
|
||||
self.loss_type = "ipo"
|
||||
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
||||
self.loss_type = loss_type
|
||||
return res
|
||||
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from requests import HTTPError
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
@@ -14,6 +17,7 @@ from axolotl.core.trainers.grpo.trainer import (
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
from axolotl.utils.schemas.vllm import VllmConfig
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -41,9 +45,20 @@ class GRPOStrategy:
|
||||
return grpo_args_kwargs
|
||||
|
||||
trl: TRLConfig = cfg.trl # type: ignore
|
||||
vllm_cfg: VllmConfig = cfg.vllm # type: ignore
|
||||
|
||||
if trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||
if trl.vllm_mode:
|
||||
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
||||
if trl.vllm_mode == "colocate":
|
||||
grpo_args_kwargs["enable_sleep_mode"] = trl.vllm_enable_sleep_mode # type: ignore[attr-defined]
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
vllm_cfg.gpu_memory_utilization
|
||||
)
|
||||
grpo_args_kwargs["vllm_tensor_parallel_size"] = (
|
||||
vllm_cfg.tensor_parallel_size
|
||||
)
|
||||
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined]
|
||||
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined]
|
||||
if trl.vllm_server_timeout:
|
||||
@@ -69,8 +84,13 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
||||
if cfg.context_parallel_size > 1:
|
||||
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
|
||||
|
||||
if trl.importance_sampling_level is not None:
|
||||
grpo_args_kwargs["importance_sampling_level"] = (
|
||||
trl.importance_sampling_level
|
||||
)
|
||||
|
||||
if trl.reward_weights:
|
||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||
@@ -109,9 +129,7 @@ class GRPOStrategy:
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
def set_trainer_args(
|
||||
cls, cfg: DictDefault
|
||||
) -> list[Any]: # pylint: disable=unused-argument
|
||||
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
|
||||
trainer_args = []
|
||||
if cfg.trl and cfg.trl.reward_funcs:
|
||||
reward_funcs = []
|
||||
@@ -132,13 +150,13 @@ class GRPOStrategy:
|
||||
return trainer_kwargs
|
||||
|
||||
@classmethod
|
||||
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||
def get_collator(cls, *args, **kwargs):
|
||||
# No data collation is needed in GRPO, handled by trl's trainer __init__
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_blocklist_args_kwargs(cls) -> list[str]:
|
||||
return ["dataset_num_proc", "max_length"]
|
||||
return ["dataset_num_proc", "max_length", "include_tokens_per_second"]
|
||||
|
||||
@classmethod
|
||||
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
||||
@@ -168,9 +186,18 @@ class GRPOStrategy:
|
||||
"Reward function must accept at least two arguments: prompts: list and completions: list"
|
||||
)
|
||||
return reward_func
|
||||
except ModuleNotFoundError:
|
||||
except ModuleNotFoundError as exc:
|
||||
# the user has passed a string (ideally indicating the path of a reward model)
|
||||
LOG.info(
|
||||
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
||||
)
|
||||
return reward_func_fqn
|
||||
# check if it's a local dir path and not empty dir to a reward model
|
||||
pretrained_log_msg = f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
||||
if os.path.isdir(reward_func_fqn) and os.listdir(reward_func_fqn):
|
||||
LOG.info(pretrained_log_msg)
|
||||
return reward_func_fqn
|
||||
try:
|
||||
snapshot_download(reward_func_fqn, repo_type="model")
|
||||
LOG.info(pretrained_log_msg)
|
||||
return reward_func_fqn
|
||||
except HTTPError:
|
||||
raise ValueError(
|
||||
f"Reward function {reward_func_fqn} not found."
|
||||
) from exc
|
||||
|
||||
@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""Axolotl GRPO Config for GRPO training"""
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
context_parallel_size: int | None = None
|
||||
|
||||
@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
- Data is properly distributed across SP groups.
|
||||
|
||||
In the table below, the values represent dataset indices. Each SP group has
|
||||
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
|
||||
`context_parallel_size = 2` GPUs working together on the same data. There are 2
|
||||
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||
|
||||
Sequence Parallel Groups
|
||||
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
rank: Rank of current process.
|
||||
batch_size: Number of samples per batch.
|
||||
repeat_count: How many times to repeat the full sampling process.
|
||||
sequence_parallel_degree: Number of ranks in a sequence parallel group.
|
||||
context_parallel_size: Number of ranks in a sequence parallel group.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
seed: Random seed for shuffling.
|
||||
drop_last: Whether to drop the last incomplete batch.
|
||||
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
rank: int,
|
||||
batch_size: int = 1,
|
||||
repeat_count: int = 1,
|
||||
sequence_parallel_degree: int = 1,
|
||||
context_parallel_size: int = 1,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
self.rank = rank
|
||||
|
||||
# Sequence parallelism parameters
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.num_sp_groups = world_size // sequence_parallel_degree
|
||||
self.sp_group_id = rank // sequence_parallel_degree
|
||||
self.context_parallel_size = context_parallel_size
|
||||
self.num_sp_groups = world_size // context_parallel_size
|
||||
self.sp_group_id = rank // context_parallel_size
|
||||
|
||||
# Adjust dataset size for distributed sampling
|
||||
self.num_samples = len(self.dataset)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
|
||||
|
||||
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
||||
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
@@ -42,17 +41,25 @@ from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
||||
from trl.trainer.utils import pad
|
||||
|
||||
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import (
|
||||
DistributedParallelMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
)
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
|
||||
|
||||
if is_peft_available():
|
||||
# pylint: disable=unused-import
|
||||
from peft import PeftConfig
|
||||
|
||||
|
||||
class AxolotlGRPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
GRPOTrainer,
|
||||
):
|
||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||
|
||||
@@ -99,7 +106,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
# Get number of SP groups (number of processes divided by SP degree)
|
||||
num_processes = self.accelerator.num_processes
|
||||
num_sp_groups = num_processes // self.args.sequence_parallel_degree
|
||||
num_sp_groups = num_processes // self.args.context_parallel_size
|
||||
|
||||
# Calculate batch size per SP group (not per process)
|
||||
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
||||
@@ -129,7 +136,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
|
||||
f"With sequence parallelism (degree {self.args.context_parallel_size}), "
|
||||
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||
f"must be evenly divisible by the number of generations per prompt "
|
||||
f"({self.num_generations}). Given the current eval batch size, "
|
||||
@@ -166,9 +173,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
rank=self.rank,
|
||||
batch_size=effective_batch_size
|
||||
// self.num_generations
|
||||
// self.args.sequence_parallel_degree,
|
||||
// self.args.context_parallel_size,
|
||||
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
||||
sequence_parallel_degree=self.args.sequence_parallel_degree,
|
||||
context_parallel_size=self.args.context_parallel_size,
|
||||
shuffle=True,
|
||||
seed=self.args.seed,
|
||||
drop_last=True,
|
||||
@@ -215,7 +222,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
if not is_eval:
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
dataloader_params["worker_init_fn"] = partial(
|
||||
seed_worker,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
rank=self.args.process_index,
|
||||
)
|
||||
|
||||
# Create the dataloader
|
||||
dataloader = DataLoader(dataset, **dataloader_params)
|
||||
@@ -230,7 +241,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||
# slice each batch along the sequence dimension).
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
@@ -239,7 +250,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
"""Get dataloader for training"""
|
||||
train_dataset = self.train_dataset
|
||||
# pylint: disable=access-member-before-definition
|
||||
|
||||
data_collator = self.data_collator # type: ignore
|
||||
|
||||
# Handle dataset preprocessing
|
||||
@@ -252,7 +263,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
train_dataset, description="training"
|
||||
)
|
||||
else:
|
||||
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||
self.data_collator = self._get_collator_with_removed_columns(
|
||||
data_collator,
|
||||
description="training",
|
||||
)
|
||||
@@ -294,33 +305,34 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# Generate completions using either vLLM or regular generation
|
||||
if self.args.use_vllm:
|
||||
# First, have main process load weights if needed
|
||||
# pylint: disable=access-member-before-definition
|
||||
|
||||
if self.state.global_step != self._last_loaded_step: # type: ignore[has-type]
|
||||
self._move_model_to_vllm()
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
|
||||
self._last_loaded_step = self.state.global_step
|
||||
|
||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
if self.accelerator.is_main_process:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
# Calculate sequence parallel group information
|
||||
world_size = self.accelerator.num_processes
|
||||
sequence_parallel_degree = self.args.sequence_parallel_degree
|
||||
num_sp_groups = world_size // sequence_parallel_degree
|
||||
context_parallel_size = self.args.context_parallel_size
|
||||
num_sp_groups = world_size // context_parallel_size
|
||||
|
||||
# Since processes in the same SP group have the same prompts, we need to ensure
|
||||
# we only take one copy of each prompt from each SP group
|
||||
ordered_set_of_prompts = []
|
||||
for sp_group_id in range(num_sp_groups):
|
||||
# Get the first process from each SP group (typically the group leader)
|
||||
group_leader_rank = sp_group_id * sequence_parallel_degree
|
||||
group_leader_rank = sp_group_id * context_parallel_size
|
||||
|
||||
# Extract prompts from this SP group, accounting for num_generations duplicates
|
||||
# We only need prompts from one rank in each SP group
|
||||
group_prompts = all_prompts_text[
|
||||
group_leader_rank
|
||||
* len(prompts_text) : (group_leader_rank + 1)
|
||||
group_leader_rank * len(prompts_text) : (
|
||||
group_leader_rank + 1
|
||||
)
|
||||
* len(prompts_text) : self.num_generations
|
||||
]
|
||||
|
||||
@@ -330,7 +342,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||
# prompt individually.
|
||||
ordered_set_of_prompts = all_prompts_text[
|
||||
:: self.num_generations * self.args.sequence_parallel_degree
|
||||
:: self.num_generations * self.args.context_parallel_size
|
||||
]
|
||||
|
||||
with profiling_context(self, "vLLM.generate"):
|
||||
@@ -347,14 +359,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
)
|
||||
else:
|
||||
completion_ids = [None] * (
|
||||
len(all_prompts_text) // self.args.sequence_parallel_degree
|
||||
len(all_prompts_text) // self.args.context_parallel_size
|
||||
)
|
||||
|
||||
# Broadcast the completions from the main process to all processes
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
|
||||
# Determine the appropriate slice based on sequence parallelism
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
@@ -471,7 +483,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
)
|
||||
if is_conversational(inputs[0]):
|
||||
completions = []
|
||||
for prompt, completion in zip(prompts, completions_text):
|
||||
for prompt, completion in zip(prompts, completions_text, strict=False):
|
||||
bootstrap = (
|
||||
prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
|
||||
)
|
||||
@@ -489,6 +501,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
self.reward_funcs,
|
||||
self.reward_processing_classes,
|
||||
self.reward_func_names,
|
||||
strict=False,
|
||||
)
|
||||
):
|
||||
with profiling_context(self, reward_func_name):
|
||||
@@ -497,14 +510,17 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
): # Module instead of PretrainedModel for compat with compiled models
|
||||
if is_conversational(inputs[0]):
|
||||
messages = [
|
||||
{"messages": p + c} for p, c in zip(prompts, completions)
|
||||
{"messages": p + c}
|
||||
for p, c in zip(prompts, completions, strict=False)
|
||||
]
|
||||
texts = [
|
||||
apply_chat_template(x, reward_processing_class)["text"]
|
||||
for x in messages
|
||||
]
|
||||
else:
|
||||
texts = [p + c for p, c in zip(prompts, completions)]
|
||||
texts = [
|
||||
p + c for p, c in zip(prompts, completions, strict=False)
|
||||
]
|
||||
reward_inputs = reward_processing_class(
|
||||
text=texts,
|
||||
return_tensors="pt",
|
||||
@@ -550,7 +566,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
row_reward_kwargs["completion"] = completions[nan_row_idx]
|
||||
warnings.warn(
|
||||
f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
|
||||
"Please ensure that at least one reward function returns a valid reward."
|
||||
"Please ensure that at least one reward function returns a valid reward.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
|
||||
@@ -578,7 +595,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
advantages = advantages / (std_grouped_rewards + 1e-4)
|
||||
|
||||
# Slice to keep only the local part of the data
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
|
||||
@@ -14,8 +14,8 @@ class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
"""Init for axolotl.core.trainers.mixins"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .activation_checkpointing import ActivationOffloadingMixin
|
||||
from .checkpoints import CheckpointSaveMixin
|
||||
from .distributed_parallel import DistributedParallelMixin
|
||||
from .optimizer import OptimizerMixin
|
||||
from .packing import PackingMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
|
||||
217
src/axolotl/core/trainers/mixins/activation_checkpointing.py
Normal file
217
src/axolotl/core/trainers/mixins/activation_checkpointing.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Trainer mixin for activation checkpointing w offloading
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
|
||||
from peft import PeftModel
|
||||
from torch import nn
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
apply_activation_checkpointing,
|
||||
)
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from transformers import GradientCheckpointingLayer, Trainer
|
||||
from trl.models.activation_offloading import (
|
||||
NoOpManager,
|
||||
OffloadActivations,
|
||||
get_act_offloading_ctx_manager,
|
||||
)
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class ActivationOffloadingMixin(Trainer):
|
||||
"""
|
||||
Trainer mixin class for activation checkpointing w offloading
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.args.activation_offloading:
|
||||
if isinstance(self.model, PeftModel):
|
||||
self.activation_offload_context = get_lora_act_offloading_ctx_manager(
|
||||
self.model, use_streams=True
|
||||
)
|
||||
else:
|
||||
self.activation_offload_context = get_act_offloading_ctx_manager(
|
||||
self.model, use_streams=True
|
||||
)
|
||||
else:
|
||||
self.activation_offload_context = contextlib.nullcontext()
|
||||
|
||||
def training_step(self, *args, **kwargs):
|
||||
with self.activation_offload_context:
|
||||
return super().training_step(*args, **kwargs)
|
||||
|
||||
|
||||
def ac_wrap_hf_model(model: nn.Module, **kwargs):
|
||||
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
|
||||
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
|
||||
|
||||
|
||||
def get_lora_act_offloading_ctx_manager(
|
||||
model: nn.Module,
|
||||
use_pin_memory: bool = True,
|
||||
use_streams: bool = True,
|
||||
min_offload_size: int = 1024,
|
||||
max_fwd_stash_size: int = 5,
|
||||
warn_if_no_head: bool = True,
|
||||
) -> OffloadActivations:
|
||||
"""
|
||||
Returns the activation offloading context manager for the model. All but the last output Linear in every step will
|
||||
be offloaded.
|
||||
|
||||
If activation offloading is enabled, we return the OffloadActivations context manager. If activation offloading is
|
||||
disabled, we return a NoOpManager context manager.
|
||||
|
||||
Args:
|
||||
model (`nn.Module`):
|
||||
Model to wrap with the activation offloading context manager.
|
||||
use_pin_memory (`bool`, *optional*, defaults to `True`):
|
||||
Whether to offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to
|
||||
be moved back onto GPU more quickly but is a limited resource.
|
||||
use_streams (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use streams for performance optimization where the communications get overlapped with the
|
||||
computation. Requires a torch build after torch-2.5.0.
|
||||
min_offload_size (`int`, *optional*, defaults to `1024`):
|
||||
Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we
|
||||
do not want to waste bandwidth and resources moving it to CPU and back.
|
||||
max_fwd_stash_size (`int`, *optional*, defaults to `5`):
|
||||
Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
|
||||
the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
|
||||
more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
|
||||
alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
|
||||
runtime.
|
||||
warn_if_no_head (`bool`, *optional*, defaults to `True`):
|
||||
Whether to warn if no output head is detected. If set to `False`, no warning will be raised if no output
|
||||
head is detected.
|
||||
|
||||
Returns:
|
||||
`contextlib.ContextDecorator`:
|
||||
Activation offloading context manager for the model.
|
||||
"""
|
||||
|
||||
activations_handling_ctx = OffloadActivations(
|
||||
use_pin_memory=use_pin_memory,
|
||||
use_streams=use_streams,
|
||||
min_offload_size=min_offload_size,
|
||||
max_fwd_stash_size=max_fwd_stash_size,
|
||||
)
|
||||
|
||||
# Below is our hack to disable offloading the last output Linear in every
|
||||
# step, as the cost for offloading the activation and then soon after bringing
|
||||
# it back is expensive.
|
||||
output_head_detected = False
|
||||
noop_ctx = NoOpManager()
|
||||
|
||||
# Try to get the actual model if it's wrapped
|
||||
unwrapped_model = model
|
||||
if hasattr(unwrapped_model, "module"):
|
||||
unwrapped_model = unwrapped_model.module
|
||||
# check for PEFT models
|
||||
if hasattr(unwrapped_model, "base_model") and hasattr(
|
||||
unwrapped_model, "peft_config"
|
||||
):
|
||||
unwrapped_model = unwrapped_model.base_model
|
||||
|
||||
# Check for different types of output heads
|
||||
if hasattr(unwrapped_model, "output"):
|
||||
if isinstance(unwrapped_model.output, nn.Module):
|
||||
unwrapped_model.output.register_forward_pre_hook(
|
||||
lambda *args: noop_ctx.__enter__()
|
||||
)
|
||||
unwrapped_model.output.register_forward_hook(
|
||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
||||
)
|
||||
output_head_detected = True
|
||||
elif hasattr(unwrapped_model.output, "linear") and isinstance(
|
||||
unwrapped_model.output.linear, nn.Module
|
||||
):
|
||||
unwrapped_model.output.linear.register_forward_pre_hook(
|
||||
lambda *args: noop_ctx.__enter__()
|
||||
)
|
||||
unwrapped_model.output.linear.register_forward_hook(
|
||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
||||
)
|
||||
output_head_detected = True
|
||||
|
||||
# Check for HuggingFace model output heads
|
||||
elif hasattr(unwrapped_model, "lm_head"):
|
||||
unwrapped_model.lm_head.register_forward_pre_hook(
|
||||
lambda *args: noop_ctx.__enter__()
|
||||
)
|
||||
unwrapped_model.lm_head.register_forward_hook(
|
||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
||||
)
|
||||
output_head_detected = True
|
||||
|
||||
# Check for decoder-based models
|
||||
elif hasattr(unwrapped_model, "decoder"):
|
||||
decoder = unwrapped_model.decoder
|
||||
if hasattr(decoder, "output"):
|
||||
decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
decoder.output.register_forward_hook(
|
||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
||||
)
|
||||
output_head_detected = True
|
||||
# Some models have lm_head in the decoder
|
||||
elif hasattr(decoder, "lm_head"):
|
||||
decoder.lm_head.register_forward_pre_hook(
|
||||
lambda *args: noop_ctx.__enter__()
|
||||
)
|
||||
decoder.lm_head.register_forward_hook(
|
||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
||||
)
|
||||
output_head_detected = True
|
||||
|
||||
# Check for transformer models with final layer norm
|
||||
elif hasattr(unwrapped_model, "final_layer_norm") or hasattr(
|
||||
unwrapped_model, "ln_f"
|
||||
):
|
||||
final_norm = (
|
||||
getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f
|
||||
)
|
||||
final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
final_norm.register_forward_hook(
|
||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
||||
)
|
||||
output_head_detected = True
|
||||
|
||||
# Check for models with head module
|
||||
elif hasattr(unwrapped_model, "head") and isinstance(
|
||||
unwrapped_model.head, nn.Module
|
||||
):
|
||||
unwrapped_model.head.register_forward_pre_hook(
|
||||
lambda *args: noop_ctx.__enter__()
|
||||
)
|
||||
unwrapped_model.head.register_forward_hook(
|
||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
||||
)
|
||||
output_head_detected = True
|
||||
|
||||
if not output_head_detected and warn_if_no_head:
|
||||
LOG.warning(
|
||||
"During activation offloading, no output head was detected. If your model has an output head, it will be "
|
||||
"offloaded. This usually greatly slows training, given the large vocabulary size. To change this "
|
||||
"behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by "
|
||||
"passing `warn_if_no_head=False`."
|
||||
)
|
||||
|
||||
for name, module in unwrapped_model.named_modules():
|
||||
# Disable offloading for any Liger modules
|
||||
if "liger" in name.lower():
|
||||
module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
module.register_forward_hook(
|
||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
||||
)
|
||||
# disable offloading for any submodules to fix LoRA training
|
||||
if name.endswith("._checkpoint_wrapped_module"):
|
||||
for _, sub_module in module.named_modules():
|
||||
sub_module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
sub_module.register_forward_hook(
|
||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
||||
)
|
||||
|
||||
return activations_handling_ctx
|
||||
23
src/axolotl/core/trainers/mixins/checkpoints.py
Normal file
23
src/axolotl/core/trainers/mixins/checkpoints.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Custom handling to not fail training if fsdp optimizer is not savable"""
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class CheckpointSaveMixin(Trainer):
|
||||
"""Mixin to handle saving the optimizer and scheduler if they are not savable."""
|
||||
|
||||
def _save_optimizer_and_scheduler(self, output_dir):
|
||||
try:
|
||||
super()._save_optimizer_and_scheduler(output_dir)
|
||||
except (NotImplementedError, KeyError) as exc:
|
||||
# TODO: fix fsdp2 optimizer saving
|
||||
LOG.warning_once(
|
||||
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
||||
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
||||
"for this training run will not be possible.",
|
||||
main_process_only=True,
|
||||
)
|
||||
32
src/axolotl/core/trainers/mixins/distributed_parallel.py
Normal file
32
src/axolotl/core/trainers/mixins/distributed_parallel.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Mixin for correctly saving fsdp
|
||||
"""
|
||||
|
||||
from accelerate import PartialState
|
||||
from transformers import Trainer
|
||||
|
||||
|
||||
class DistributedParallelMixin(Trainer):
|
||||
"""
|
||||
Mixin for correctly saving fsdp
|
||||
"""
|
||||
|
||||
def _save(self, output_dir: str | None = None, state_dict=None):
|
||||
if (
|
||||
state_dict is None
|
||||
and self.accelerator.parallelism_config
|
||||
and self.accelerator.parallelism_config.dp_shard_enabled
|
||||
):
|
||||
state_dict = self.accelerator.get_state_dict(self.model)
|
||||
super()._save(output_dir, state_dict=state_dict)
|
||||
|
||||
def create_accelerator_and_postprocess(self):
|
||||
super().create_accelerator_and_postprocess()
|
||||
if (
|
||||
self.accelerator.distributed_type == "FSDP"
|
||||
and self.accelerator.state.fsdp_plugin is None
|
||||
):
|
||||
# handle Context Parallelism without FSDP
|
||||
self.accelerator.state.distributed_type = "MULTI_GPU"
|
||||
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
|
||||
PartialState().distributed_type = "MULTI_GPU"
|
||||
@@ -70,11 +70,11 @@ class OptimizerMixin(Trainer):
|
||||
}
|
||||
)
|
||||
if params["embeddings"]:
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
lr = optimizer_kwargs["lr"]
|
||||
if self.args.embedding_lr_scale:
|
||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||
lr *= self.args.embedding_lr_scale
|
||||
elif self.args.embedding_lr:
|
||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||
lr = self.args.embedding_lr
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["embeddings"].values()),
|
||||
@@ -143,7 +143,7 @@ class OptimizerMixin(Trainer):
|
||||
loraplus_lr_embedding = getattr(
|
||||
self.args, "loraplus_lr_embedding", 1e-6
|
||||
)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer = create_loraplus_optimizer(
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
@@ -185,17 +185,15 @@ class OptimizerMixin(Trainer):
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
LOG.info(f"skipped {module}: {skipped / 2**20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
LOG.info(f"skipped: {skipped / 2**20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
self.optimizer = smp.DistributedOptimizer(self.optimizer)
|
||||
|
||||
return self.optimizer
|
||||
|
||||
|
||||
20
src/axolotl/core/trainers/mixins/packing.py
Normal file
20
src/axolotl/core/trainers/mixins/packing.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Trainer mixin to support packing"""
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
|
||||
class PackingMixin(Trainer):
|
||||
"""
|
||||
Trainer mixin to support packing
|
||||
"""
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
super()._set_signature_columns_if_needed()
|
||||
if (
|
||||
self._signature_columns
|
||||
and self.args.sample_packing
|
||||
and self.args.sample_packing_drop_attention_mask
|
||||
):
|
||||
set_sig_columns = set(self._signature_columns)
|
||||
set_sig_columns.remove("attention_mask")
|
||||
self._signature_columns = list(set_sig_columns)
|
||||
@@ -7,6 +7,7 @@ from transformers.trainer import Trainer
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schedulers import (
|
||||
JaggedLRRestartScheduler,
|
||||
RexLR,
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
@@ -45,7 +46,7 @@ class SchedulerMixin(Trainer):
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
if self.lr_scheduler is None: # type: ignore
|
||||
# fmt: on
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
|
||||
@@ -89,7 +90,7 @@ class SchedulerMixin(Trainer):
|
||||
LOG.warning(
|
||||
"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
@@ -97,7 +98,7 @@ class SchedulerMixin(Trainer):
|
||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
@@ -106,14 +107,14 @@ class SchedulerMixin(Trainer):
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||
self.lr_scheduler = get_cosine_schedule_with_min_lr(
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning(
|
||||
@@ -123,4 +124,22 @@ class SchedulerMixin(Trainer):
|
||||
LOG.warning(
|
||||
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||
|
||||
if self.args.jagged_restart_steps:
|
||||
warmup_steps = (
|
||||
self.args.jagged_restart_warmup_steps or 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.jagged_restart_anneal_steps or 1
|
||||
)
|
||||
if not self.lr_scheduler:
|
||||
super().create_scheduler(num_training_steps, optimizer)
|
||||
self.lr_scheduler = JaggedLRRestartScheduler(
|
||||
optimizer,
|
||||
self.lr_scheduler,
|
||||
self.args.jagged_restart_steps,
|
||||
warmup_steps,
|
||||
anneal_steps,
|
||||
min_lr_scale=self.args.cosine_min_lr_ratio or 0.001,
|
||||
)
|
||||
|
||||
return self.lr_scheduler # type: ignore
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Module for ReLoRA trainer"""
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
|
||||
|
||||
tag_names = ["axolotl", "relora"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: torch.optim.Optimizer | None = None,
|
||||
) -> LRScheduler:
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler: LRScheduler = super().create_scheduler(
|
||||
num_training_steps, optimizer
|
||||
)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler( # type: ignore
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
anneal_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler # type: ignore
|
||||
|
||||
return self.lr_scheduler # type: ignore
|
||||
@@ -1,81 +1,25 @@
|
||||
"""Module for TRL PPO trainer"""
|
||||
"""Module for TRL RL trainers"""
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from trl import (
|
||||
CPOTrainer,
|
||||
KTOTrainer,
|
||||
ORPOTrainer,
|
||||
PPOTrainer,
|
||||
PRMTrainer,
|
||||
RewardTrainer,
|
||||
)
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
class TRLPPOTrainer(PPOTrainer):
|
||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||
|
||||
tag_names = ["axolotl", "ppo"]
|
||||
|
||||
def train(
|
||||
self,
|
||||
reward_pipe,
|
||||
resume_from_checkpoint=None, # pylint: disable=unused-argument
|
||||
):
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": self.tokenizer.eos_token_id,
|
||||
"max_new_tokens": 32,
|
||||
}
|
||||
sent_kwargs = {
|
||||
"return_all_scores": True,
|
||||
"function_to_apply": "none",
|
||||
"batch_size": 16,
|
||||
}
|
||||
|
||||
for _, batch in tqdm(enumerate(self.dataloader)):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
# generate model response
|
||||
response_tensors, ref_response_tensors = self.generate(
|
||||
query_tensors,
|
||||
return_prompt=False,
|
||||
generate_ref_response=True,
|
||||
**generation_kwargs,
|
||||
)
|
||||
batch["response"] = self.tokenizer.batch_decode(response_tensors)
|
||||
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
|
||||
|
||||
# Compute sentiment score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = reward_pipe(texts, **sent_kwargs)
|
||||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
||||
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
|
||||
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
|
||||
ref_rewards = [
|
||||
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
|
||||
]
|
||||
batch["ref_rewards"] = ref_rewards
|
||||
|
||||
# Run PPO step
|
||||
stats = self.step(query_tensors, response_tensors, rewards)
|
||||
self.log_stats(
|
||||
stats,
|
||||
batch,
|
||||
rewards,
|
||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||
)
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
ORPOTrainer,
|
||||
):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
@@ -85,7 +29,12 @@ class AxolotlORPOTrainer(
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
KTOTrainer,
|
||||
):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
@@ -95,7 +44,12 @@ class AxolotlKTOTrainer(
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
CPOTrainer,
|
||||
):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
@@ -105,7 +59,12 @@ class AxolotlCPOTrainer(
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
RewardTrainer,
|
||||
):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
@@ -115,7 +74,12 @@ class AxolotlRewardTrainer(
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
PRMTrainer,
|
||||
):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
|
||||
@@ -2,238 +2,17 @@
|
||||
extra axolotl specific training args
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Type
|
||||
|
||||
from PIL.Image import Resampling
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
from axolotl.integrations.config import merge_training_args
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
pretraining: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||
},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
sample_packing_sequentially: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||
},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
sample_packing_bin_size: int = field(
|
||||
default=200,
|
||||
metadata={
|
||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
sample_packing_group_size: int = field(
|
||||
default=100000,
|
||||
metadata={
|
||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_prune_ratio: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
do_causal_lm_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
cosine_min_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||
)
|
||||
cosine_constant_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = field(
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
embedding_lr_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
lr_groups: Optional[list[dict]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Specify learning rate groups for with different LRs."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
orpo_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
)
|
||||
lisa_n_layers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "the number of activate layers in LISA"},
|
||||
)
|
||||
lisa_step_interval: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
alternate_lr_scheduler_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||
},
|
||||
)
|
||||
chat_template: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Chat template converting chat messages to text"},
|
||||
)
|
||||
|
||||
kd_ce_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_alpha: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The alpha scaling parameter for KD loss"},
|
||||
)
|
||||
|
||||
kd_temperature: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={
|
||||
"help": "the temperature parameter for KL divergence loss when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_zscore_base_temp: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "the base temperature parameter for KL divergence with z-score when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_top_k_before_softmax: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
adam_beta3: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
adam_epsilon2: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The size of the image to resize to"},
|
||||
)
|
||||
|
||||
image_resize_algorithm: Resampling | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The algorithm to use for image resizing"},
|
||||
)
|
||||
|
||||
# end of multi-modal section
|
||||
AxolotlTrainingMixins: Type = merge_training_args()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
265
src/axolotl/core/training_args_base.py
Normal file
265
src/axolotl/core/training_args_base.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Base Axolotl Training Mixins shared across various trainer configs
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from PIL.Image import Resampling
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
pretraining: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||
},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
sample_packing_sequentially: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||
},
|
||||
)
|
||||
sample_packing_mp_start_method: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The multiprocessing start method to use."},
|
||||
)
|
||||
sample_packing_drop_attention_mask: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Drop attention mask from inputs when using packing."},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
include_tkps: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to include tokens per second in the training metrics."
|
||||
},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
sample_packing_bin_size: int = field(
|
||||
default=200,
|
||||
metadata={
|
||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
sample_packing_group_size: int = field(
|
||||
default=100000,
|
||||
metadata={
|
||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
dataset_num_proc: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for data processing"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_prune_ratio: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||
)
|
||||
jagged_restart_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for jagged restarts"},
|
||||
)
|
||||
jagged_restart_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "how many warmup steps to take after reset for jagged restarts"
|
||||
},
|
||||
)
|
||||
jagged_restart_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "how many anneal steps to take before reset for jagged restarts"
|
||||
},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
do_causal_lm_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
cosine_min_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||
)
|
||||
cosine_constant_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = field(
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
embedding_lr_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
lr_groups: Optional[list[dict]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Specify learning rate groups for with different LRs."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
orpo_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
)
|
||||
lisa_n_layers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "the number of activate layers in LISA"},
|
||||
)
|
||||
lisa_step_interval: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
alternate_lr_scheduler_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||
},
|
||||
)
|
||||
chat_template: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Chat template converting chat messages to text"},
|
||||
)
|
||||
|
||||
# kd_ce_alpha: Optional[float] = field(
|
||||
# default=None,
|
||||
# metadata={
|
||||
# "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
|
||||
# },
|
||||
# )
|
||||
#
|
||||
# kd_alpha: Optional[float] = field(
|
||||
# default=1.0,
|
||||
# metadata={"help": "The alpha scaling parameter for KD loss"},
|
||||
# )
|
||||
#
|
||||
# kd_temperature: Optional[float] = field(
|
||||
# default=1.0,
|
||||
# metadata={
|
||||
# "help": "the temperature parameter for KL divergence loss when using KD"
|
||||
# },
|
||||
# )
|
||||
|
||||
adam_beta3: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
adam_epsilon2: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
|
||||
activation_offloading: bool | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Use activation offloading with CUDA streams for training."},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The size of the image to resize to"},
|
||||
)
|
||||
|
||||
image_resize_algorithm: Resampling | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The algorithm to use for image resizing"},
|
||||
)
|
||||
|
||||
# end of multi-modal section
|
||||
|
||||
dion_learning_rate: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The learning rate for Dion"},
|
||||
)
|
||||
dion_momentum: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The momentum for Dion"},
|
||||
)
|
||||
dion_rank_fraction: float | None = field(
|
||||
default=None,
|
||||
)
|
||||
dion_rank_multiple_of: int | None = field(
|
||||
default=None,
|
||||
)
|
||||
@@ -1,40 +1,36 @@
|
||||
"""Module containing Dataset functionality"""
|
||||
"""
|
||||
Module containing dataset functionality.
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
|
||||
concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
|
||||
datasets.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
||||
# lets use the concept of middlewares to wrap each dataset, for example
|
||||
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
|
||||
# let's check to ensure we don't truncate an item in the middle, we'll use
|
||||
# the collators later on to pad the datasets
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenizedPromptDataset(Dataset):
|
||||
"""
|
||||
Dataset that returns tokenized prompts from a stream of text files.
|
||||
Args:
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
process_count (int): Number of processes to use for tokenizing.
|
||||
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
|
||||
"""Dataset that returns tokenized prompts from a stream of text files.
|
||||
|
||||
Args:
|
||||
prompt_tokenizer: The prompt tokenizing method for processing the data.
|
||||
dataset: Dataset with text files.
|
||||
process_count: Number of processes to use for tokenizing.
|
||||
keep_in_memory: Whether to keep the tokenized dataset in memory.
|
||||
"""
|
||||
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
def __init__(
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: Dataset,
|
||||
process_count: Optional[int] = None,
|
||||
keep_in_memory: Optional[bool] = False,
|
||||
process_count: int | None = None,
|
||||
keep_in_memory: bool | None = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
@@ -47,7 +43,6 @@ class TokenizedPromptDataset(Dataset):
|
||||
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
||||
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
@@ -60,13 +55,13 @@ class TokenizedPromptDataset(Dataset):
|
||||
):
|
||||
dataset = dataset.filter(
|
||||
self.prompt_tokenizer.filter_rows,
|
||||
num_proc=num_proc,
|
||||
num_proc=self.process_count,
|
||||
desc="Strategy Filtering Rows",
|
||||
)
|
||||
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
num_proc=self.process_count,
|
||||
remove_columns=features,
|
||||
keep_in_memory=self.keep_in_memory,
|
||||
desc="Tokenizing Prompts",
|
||||
@@ -76,143 +71,17 @@ class TokenizedPromptDataset(Dataset):
|
||||
|
||||
def wrap_dataset_for_tokenized_prompt(
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: Union[Dataset, IterableDataset],
|
||||
dataset: Dataset | IterableDataset,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(dataset, IterableDataset):
|
||||
map_kwargs = {}
|
||||
if prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
features = dataset.features.keys()
|
||||
features = list(dataset.features.keys())
|
||||
return dataset.map(
|
||||
prompt_tokenizer.tokenize_prompt,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
)
|
||||
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
||||
|
||||
|
||||
# TODO this isn't the best since it can't interleave datasets
|
||||
class ConstantLengthDataset(IterableDataset):
|
||||
"""
|
||||
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
||||
Args:
|
||||
tokenizer (Tokenizer): The processor used for processing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
seq_length (int): Length of token sequences to return.
|
||||
"""
|
||||
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self,
|
||||
tokenizer,
|
||||
datasets,
|
||||
seq_length=2048,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.concat_token_id = tokenizer.eos_token_id
|
||||
self.datasets: List[IterableDataset] = datasets
|
||||
self.seq_length = seq_length
|
||||
|
||||
vocab_size = len(tokenizer.get_vocab())
|
||||
|
||||
if vocab_size <= torch.iinfo(torch.int16).max:
|
||||
self.tokens_dtype = torch.int16
|
||||
elif vocab_size <= torch.iinfo(torch.int32).max:
|
||||
self.tokens_dtype = torch.int32
|
||||
else:
|
||||
self.tokens_dtype = torch.int64
|
||||
|
||||
def __iter__(self):
|
||||
buffer = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"position_ids": [],
|
||||
}
|
||||
buffer_len = 0
|
||||
for dataset in self.datasets:
|
||||
idx = 0
|
||||
iterator = iter(dataset)
|
||||
more_examples = True
|
||||
while more_examples:
|
||||
try:
|
||||
example = next(iterator)
|
||||
idx += 1
|
||||
except StopIteration:
|
||||
more_examples = False
|
||||
example = None
|
||||
|
||||
add_concat_token = False
|
||||
if example:
|
||||
example_len = len(example["input_ids"])
|
||||
add_concat_token = example["input_ids"][-1] != self.concat_token_id
|
||||
else:
|
||||
example_len = 0
|
||||
|
||||
if not example_len or (
|
||||
buffer_len + int(add_concat_token) + example_len > self.seq_length
|
||||
):
|
||||
if buffer["input_ids"]:
|
||||
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
|
||||
: self.seq_length
|
||||
]
|
||||
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
|
||||
: self.seq_length
|
||||
]
|
||||
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
|
||||
: self.seq_length
|
||||
]
|
||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||
if labels.size() == input_ids.size() and (
|
||||
attention_mask.size() == input_ids.size()
|
||||
):
|
||||
yield {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
else:
|
||||
LOG.warning(
|
||||
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
|
||||
)
|
||||
buffer = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"position_ids": [],
|
||||
}
|
||||
buffer_len = 0
|
||||
idx = 1
|
||||
|
||||
if example:
|
||||
# FIXME
|
||||
# just going to drop data points that are too long
|
||||
if len(example["input_ids"]) <= self.seq_length:
|
||||
input_ids = example["input_ids"]
|
||||
attention_mask = example["attention_mask"]
|
||||
labels = example["labels"]
|
||||
|
||||
if add_concat_token:
|
||||
input_ids.append(self.concat_token_id)
|
||||
attention_mask.append(1)
|
||||
labels.append(self.concat_token_id)
|
||||
|
||||
input_ids_with_concat = torch.tensor(
|
||||
input_ids, dtype=self.tokens_dtype
|
||||
)
|
||||
attention_mask_with_concat = torch.tensor(
|
||||
[idx * m for m in attention_mask], dtype=torch.int16
|
||||
)
|
||||
labels_with_concat = torch.tensor(
|
||||
labels, dtype=self.tokens_dtype
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
len(input_ids), dtype=self.tokens_dtype
|
||||
)
|
||||
|
||||
buffer["input_ids"].append(input_ids_with_concat)
|
||||
buffer["attention_mask"].append(attention_mask_with_concat)
|
||||
buffer["labels"].append(labels_with_concat)
|
||||
buffer["position_ids"].append(position_ids)
|
||||
buffer_len += len(input_ids)
|
||||
|
||||
@@ -7,7 +7,6 @@ from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
@@ -18,6 +17,7 @@ from axolotl.train import (
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
@@ -81,7 +81,7 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
|
||||
model, tokenizer, _, processor = setup_model_and_tokenizer(cfg)
|
||||
|
||||
# Get datasets
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
eval_dataset = dataset_meta.eval_dataset
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
@@ -22,17 +22,20 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import importlib
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
||||
|
||||
from peft import PeftModel
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from transformers import PreTrainedModel, Trainer
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__, use_environ=True)
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.common.datasets import TrainDatasetMeta
|
||||
@@ -73,8 +76,8 @@ class BasePlugin:
|
||||
def __init__(self):
|
||||
"""Initializes the BasePlugin."""
|
||||
|
||||
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
|
||||
"""Registers the plugin with the given configuration.
|
||||
def register(self, cfg: dict):
|
||||
"""Registers the plugin with the given configuration as an unparsed dict.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugin.
|
||||
@@ -83,6 +86,11 @@ class BasePlugin:
|
||||
def get_input_args(self) -> str | None:
|
||||
"""Returns a pydantic model for the plugin's input arguments."""
|
||||
|
||||
def get_training_args_mixin(self) -> str | None:
|
||||
"""
|
||||
Returns a dataclass model for the plugin's training arguments.
|
||||
"""
|
||||
|
||||
def load_datasets(
|
||||
self, cfg: DictDefault, preprocess: bool = False
|
||||
) -> Union["TrainDatasetMeta", None]:
|
||||
@@ -96,14 +104,13 @@ class BasePlugin:
|
||||
dataset_meta: The metadata for the training dataset.
|
||||
"""
|
||||
|
||||
def pre_model_load(self, cfg: DictDefault): # pylint: disable=unused-argument
|
||||
def pre_model_load(self, cfg: DictDefault):
|
||||
"""Performs actions before the model is loaded.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugin.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):
|
||||
"""Performs actions after the model is built/loaded, but before any adapters are applied.
|
||||
|
||||
@@ -111,7 +118,6 @@ class BasePlugin:
|
||||
cfg: The configuration for the plugin.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):
|
||||
"""Performs actions before LoRA weights are loaded.
|
||||
|
||||
@@ -120,7 +126,6 @@ class BasePlugin:
|
||||
model: The loaded model.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||
"""Performs actions after LoRA weights are loaded.
|
||||
|
||||
@@ -129,7 +134,6 @@ class BasePlugin:
|
||||
model: The loaded model.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||
"""Performs actions after the model is loaded.
|
||||
|
||||
@@ -138,8 +142,7 @@ class BasePlugin:
|
||||
model: The loaded model.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
|
||||
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
|
||||
"""Returns a custom class for the trainer.
|
||||
|
||||
Args:
|
||||
@@ -149,7 +152,6 @@ class BasePlugin:
|
||||
The first non-`None` trainer class returned by a plugin.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
||||
"""Performs actions after the trainer is created.
|
||||
|
||||
@@ -158,7 +160,29 @@ class BasePlugin:
|
||||
trainer: The trainer object for training.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def get_training_args(self, cfg: DictDefault):
|
||||
"""
|
||||
Returns custom training arguments to set on TrainingArgs.
|
||||
|
||||
Args:
|
||||
cfg: The global axolotl configuration.
|
||||
|
||||
Returns:
|
||||
object: dict containing the training arguments.
|
||||
"""
|
||||
|
||||
def get_collator_cls_and_kwargs(self, cfg: DictDefault, is_eval: bool = False):
|
||||
"""
|
||||
Returns a custom class for the collator.
|
||||
|
||||
Args:
|
||||
cfg: The global axolotl configuration.
|
||||
is_eval: Whether this is an eval split.
|
||||
|
||||
Returns:
|
||||
class: The class for the collator.
|
||||
"""
|
||||
|
||||
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
|
||||
"""Creates and returns an optimizer for training.
|
||||
|
||||
@@ -170,7 +194,6 @@ class BasePlugin:
|
||||
The created optimizer.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def create_lr_scheduler(
|
||||
self,
|
||||
cfg: DictDefault,
|
||||
@@ -190,7 +213,6 @@ class BasePlugin:
|
||||
The created learning rate scheduler.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def add_callbacks_pre_trainer(
|
||||
self, cfg: DictDefault, model: PreTrainedModel
|
||||
) -> list[Callable]:
|
||||
@@ -205,7 +227,6 @@ class BasePlugin:
|
||||
"""
|
||||
return []
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def add_callbacks_post_trainer(
|
||||
self, cfg: DictDefault, trainer: Trainer
|
||||
) -> list[Callable]:
|
||||
@@ -221,7 +242,6 @@ class BasePlugin:
|
||||
"""
|
||||
return []
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||
"""Performs actions after training is complete.
|
||||
|
||||
@@ -230,7 +250,7 @@ class BasePlugin:
|
||||
model: The loaded model.
|
||||
"""
|
||||
|
||||
def post_train_unload(self, cfg: DictDefault): # pylint: disable=unused-argument
|
||||
def post_train_unload(self, cfg: DictDefault):
|
||||
"""Performs actions after training is complete and the model is unloaded.
|
||||
|
||||
Args:
|
||||
@@ -337,8 +357,11 @@ class PluginManager:
|
||||
plugin = load_plugin(plugin_name)
|
||||
self.plugins[plugin_name] = plugin
|
||||
LOG.info(f"Plugin loaded successfully: {plugin_name}")
|
||||
except ImportError:
|
||||
except ImportError as exc:
|
||||
LOG.error(f"Failed to load plugin: {plugin_name}")
|
||||
# print stacktrace
|
||||
traceback.print_exc()
|
||||
print(f"Error: {exc}")
|
||||
|
||||
def get_input_args(self) -> list[str]:
|
||||
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
|
||||
@@ -353,6 +376,20 @@ class PluginManager:
|
||||
input_args.append(input_args_from_plugin)
|
||||
return input_args
|
||||
|
||||
def get_training_args_mixin(self):
|
||||
"""
|
||||
Returns a list of dataclasses for all registered plugins' training args mixins'
|
||||
|
||||
Returns:
|
||||
list[str]: A list of dataclsses
|
||||
"""
|
||||
training_args = []
|
||||
for plugin in self.plugins.values():
|
||||
training_args_from_plugin = plugin.get_training_args_mixin()
|
||||
if training_args_from_plugin is not None:
|
||||
training_args.append(training_args_from_plugin)
|
||||
return training_args
|
||||
|
||||
def load_datasets(
|
||||
self, cfg: DictDefault, preprocess: bool = False
|
||||
) -> Union["TrainDatasetMeta", None]:
|
||||
@@ -442,6 +479,42 @@ class PluginManager:
|
||||
return trainer_cls
|
||||
return None
|
||||
|
||||
def get_training_args(self, cfg):
|
||||
"""
|
||||
Calls the get_training_args method of all registered plugins and returns the combined training arguments.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
|
||||
Returns:
|
||||
object: The training arguments
|
||||
"""
|
||||
training_args_kwargs = {}
|
||||
for plugin in self.plugins.values():
|
||||
training_args = plugin.get_training_args(cfg)
|
||||
if training_args is not None:
|
||||
training_args_kwargs.update(training_args)
|
||||
|
||||
return training_args_kwargs
|
||||
|
||||
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
||||
"""
|
||||
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
is_eval (bool): Whether this is an eval split.
|
||||
|
||||
Returns:
|
||||
object: The collator class, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
|
||||
if collator is not None:
|
||||
collator_cls, collator_kwargs = collator
|
||||
return collator_cls, collator_kwargs
|
||||
return None
|
||||
|
||||
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
||||
"""Calls the `post_trainer_create` method of all registered plugins.
|
||||
|
||||
@@ -557,3 +630,24 @@ class BaseOptimizerFactory:
|
||||
self, opt_model, training_args, **optimizer_kwargs
|
||||
) -> Optimizer | None:
|
||||
pass
|
||||
|
||||
# duplicated from transformers
|
||||
def get_decay_parameter_names(self, model) -> list[str]:
|
||||
"""
|
||||
Get all parameter names that weight decay will be applied to.
|
||||
|
||||
This function filters out parameters in two ways:
|
||||
1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
|
||||
2. By parameter name patterns (containing 'bias', or variation of 'norm')
|
||||
"""
|
||||
forbidden_name_patterns = [
|
||||
r"bias",
|
||||
r"layernorm",
|
||||
r"rmsnorm",
|
||||
r"(?:^|\.)norm(?:$|\.)",
|
||||
r"_norm(?:$|\.)",
|
||||
]
|
||||
decay_parameters = get_parameter_names(
|
||||
model, [nn.LayerNorm], forbidden_name_patterns
|
||||
)
|
||||
return decay_parameters
|
||||
|
||||
@@ -16,12 +16,12 @@ Module to handle merging the plugins' input arguments with the base configuratio
|
||||
This was moved here to prevent circular imports.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
from axolotl.utils.schemas.config import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||
)
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||
|
||||
|
||||
def merge_input_args():
|
||||
@@ -50,14 +50,44 @@ def merge_input_args():
|
||||
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
|
||||
|
||||
namespace: Dict[Any, Any] = {}
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
dynamic_input, globals(), namespace
|
||||
)
|
||||
AxolotlInputConfig = namespace[ # pylint: disable=invalid-name
|
||||
"AxolotlInputConfig"
|
||||
]
|
||||
AxolotlConfigWCapabilities = namespace[ # pylint: disable=invalid-name
|
||||
"AxolotlConfigWCapabilities"
|
||||
]
|
||||
exec(dynamic_input, globals(), namespace) # nosec B102
|
||||
AxolotlInputConfig = namespace["AxolotlInputConfig"]
|
||||
AxolotlConfigWCapabilities = namespace["AxolotlConfigWCapabilities"]
|
||||
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
||||
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
||||
|
||||
|
||||
def merge_training_args() -> Type:
|
||||
"""
|
||||
Merges training arguments from registered plugins with the base TrainingArguments.
|
||||
|
||||
This function retrieves the training arguments from registered plugins using the PluginManager.
|
||||
It then dynamically creates new classes, AxolotlTrainingMixins,
|
||||
that inherit from the base configurations and include the training arguments from the plugins.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
|
||||
"""
|
||||
|
||||
from axolotl.core.training_args_base import (
|
||||
AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
|
||||
mixin_classes = []
|
||||
dynamic_input = ""
|
||||
for plugin_args in training_args_mixins:
|
||||
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
|
||||
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
|
||||
mixin_classes.append(plugin_cls)
|
||||
if dynamic_input:
|
||||
dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n"
|
||||
|
||||
namespace: Dict[Any, Any] = {}
|
||||
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
|
||||
exec(dynamic_input, {**globals(), **local_vars}, namespace) # nosec B102
|
||||
AxolotlTrainingMixins = namespace["AxolotlTrainingMixins"]
|
||||
return AxolotlTrainingMixins
|
||||
return AxolotlTrainingMixinsBase
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -31,27 +31,55 @@ plugins:
|
||||
|
||||
## Supported Models
|
||||
|
||||
- llama
|
||||
- llama4
|
||||
- llama4_text
|
||||
- mllama
|
||||
- phi3
|
||||
- apertus
|
||||
- arcee
|
||||
- cohere
|
||||
- cohere2
|
||||
- deepseek_v3
|
||||
- gemma
|
||||
- gemma2
|
||||
- gemma3
|
||||
- gemma3_text
|
||||
- gemma3n
|
||||
- gemma3n_text
|
||||
- glm
|
||||
- glm4
|
||||
- glm4_moe
|
||||
- glm4v
|
||||
- glm4v_moe
|
||||
- gpt_oss
|
||||
- granite
|
||||
- granitemoe
|
||||
- granitemoeshared
|
||||
- granitemoehybrid
|
||||
- hunyuan_v1_dense
|
||||
- hunyuan_v1_moe
|
||||
- lfm2
|
||||
- lfm2_moe
|
||||
- lfm2_vl
|
||||
- llama
|
||||
- llama4
|
||||
- llama4_text
|
||||
- llava
|
||||
- mistral
|
||||
- mistral3
|
||||
- mixtral
|
||||
- mllama
|
||||
- phi
|
||||
- phi3
|
||||
- phi4_multimodal
|
||||
- qwen2
|
||||
- qwen2_moe
|
||||
- qwen2_vl
|
||||
- qwen2_moe
|
||||
- qwen2_5_vl
|
||||
- qwen3
|
||||
- qwen3_moe
|
||||
- cohere
|
||||
- cohere2
|
||||
- glm
|
||||
- glm4
|
||||
- qwen3_vl
|
||||
- qwen3_vl_moe
|
||||
- qwen3_next
|
||||
- smollm3
|
||||
- seed_oss
|
||||
- voxtral
|
||||
|
||||
## Citation
|
||||
|
||||
|
||||
@@ -18,21 +18,24 @@ Module for the Plugin for Cut Cross Entropy integration with Axolotl.
|
||||
Cut Cross Entropy is an optimized implementation of cross entropy loss
|
||||
from Apple's ML team.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils import get_pytorch_version
|
||||
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||
from .args import CutCrossEntropyArgs as CutCrossEntropyArgs
|
||||
|
||||
LOG = get_logger(__name__, use_environ=True)
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`'
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`'
|
||||
)
|
||||
|
||||
|
||||
@@ -64,16 +67,29 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
"cut_cross_entropy.transformers"
|
||||
)
|
||||
if cce_spec_transformers is None:
|
||||
raise ImportError(_CCE_INSTALL_MESSAGE)
|
||||
raise ImportError(
|
||||
"Transformers support is not installed. " + _CCE_INSTALL_MESSAGE
|
||||
)
|
||||
|
||||
# Check if Axolotl's cce fork is installed
|
||||
try:
|
||||
from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK
|
||||
|
||||
if not AXOLOTL_CCE_FORK:
|
||||
raise ImportError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Axolotl's fork of cut_cross_entropy is not installed. "
|
||||
+ _CCE_INSTALL_MESSAGE
|
||||
) from e
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""Apply cut cross entropy before model loading if enabled."""
|
||||
if cfg.cut_cross_entropy:
|
||||
self._check_requirements()
|
||||
self.patch_llama_like(cfg.model_config_type)
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
|
||||
cce_patch,
|
||||
)
|
||||
from cut_cross_entropy.transformers.patch import cce_patch
|
||||
|
||||
LOG.info(
|
||||
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
|
||||
@@ -81,3 +97,44 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
|
||||
# The patch checks model_type internally
|
||||
cce_patch(cfg.model_config_type)
|
||||
|
||||
def patch_llama_like(
|
||||
self,
|
||||
model_type: str,
|
||||
) -> None:
|
||||
"""
|
||||
Generic patch for model architectures with causal lm similar to llama
|
||||
"""
|
||||
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
||||
|
||||
def patch_generic(maybe_model, patch_options, model_type: str):
|
||||
import cut_cross_entropy.transformers.llama
|
||||
from cut_cross_entropy.transformers.llama import cce_forward
|
||||
|
||||
try:
|
||||
# Dynamically import the module and CausalLM class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||
module = __import__(
|
||||
module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]
|
||||
)
|
||||
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
||||
|
||||
cut_cross_entropy.transformers.llama._PATCH_OPTS = patch_options
|
||||
|
||||
model_cls.forward = cce_forward
|
||||
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise RuntimeError(
|
||||
f"Could not import ForCausalLM class for model_type: {model_type}. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
|
||||
if model_type not in PATCH_FNS:
|
||||
LOG.warning_once(
|
||||
"Setting up generic cce patch for model type: %s", model_type
|
||||
)
|
||||
LOG.warning_once(
|
||||
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
|
||||
)
|
||||
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
"""
|
||||
Module for handling Cut Cross Entropy input arguments.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -41,3 +42,13 @@ class CutCrossEntropyArgs(BaseModel):
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_chunked_cross_entropy_not_set(cls, data):
|
||||
if data.get("chunked_cross_entropy"):
|
||||
raise ValueError(
|
||||
"Cut Cross Entropy does not support chunked cross entropy. "
|
||||
"Please set `chunked_cross_entropy` to `False` or disable Cut Cross Entropy."
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -1,191 +0,0 @@
|
||||
"""Cohere and Cohere2 CCE patch."""
|
||||
|
||||
# This patch is based off transformers 4.50.0.
|
||||
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
|
||||
# It scales the hidden states by the logit scale in advance instead of the logits as the
|
||||
# operation is done internally and should be mathematically equivalent.
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.cohere.modeling_cohere import (
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>> from transformers import AutoTokenizer, CohereForCausalLM
|
||||
|
||||
>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
|
||||
>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>> # Generate
|
||||
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
# scale hidden_states by logit_scale in-place of logits
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :] * self.logit_scale,
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
logits = logits * self.logit_scale # main diff from Llama
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def patch_cohere(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.cohere import modeling_cohere
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_cohere.CohereForCausalLM
|
||||
), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_cohere.CohereForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_cohere2(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.cohere2 import modeling_cohere2
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_cohere2.Cohere2ForCausalLM
|
||||
), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -1,165 +0,0 @@
|
||||
"""Gemma CCE patch"""
|
||||
|
||||
# This patch is based off transformers 4.50.0.
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||
|
||||
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
|
||||
|
||||
>>> prompt = "What is your favorite condiment?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def patch_gemma(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma.GemmaForCausalLM
|
||||
), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma.GemmaForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -1,447 +0,0 @@
|
||||
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""
|
||||
|
||||
# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
|
||||
# and updated for transformers 4.50.0.
|
||||
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
|
||||
# with both gemma3 (text and multimodal) models.
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache, HybridCache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
Gemma3CausalLMOutputWithPast,
|
||||
logger,
|
||||
)
|
||||
from transformers.utils import (
|
||||
is_torchdynamo_compiling,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
|
||||
|
||||
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
||||
|
||||
>>> prompt = "What is your favorite condiment?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**loss_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||
**loss_kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if self.config.final_logit_softcapping is not None:
|
||||
logits = logits / self.config.final_logit_softcapping
|
||||
logits = torch.tanh(logits)
|
||||
logits = logits * self.config.final_logit_softcapping
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||
|
||||
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
|
||||
>>> prompt = "answer en Where is the cow standing?"
|
||||
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_length=30)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"answer en Where is the cow standing?\nbeach"
|
||||
```"""
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_index
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids # type: ignore
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = (
|
||||
past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore
|
||||
)
|
||||
cache_position = torch.arange( # type: ignore
|
||||
past_seen_tokens,
|
||||
past_seen_tokens + inputs_embeds.shape[1],
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(
|
||||
self.config.image_token_index,
|
||||
dtype=torch.long,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
|
||||
-1
|
||||
)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
"tokens from image embeddings."
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
||||
|
||||
# mask out pad-token-ids in labels for BC
|
||||
if labels is not None and self.pad_token_id in labels:
|
||||
logger.warning_once(
|
||||
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
||||
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
||||
)
|
||||
labels = torch.where( # type: ignore
|
||||
input_ids == self.pad_token_id, self.config.ignore_index, labels
|
||||
)
|
||||
|
||||
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
past_key_values,
|
||||
cache_position,
|
||||
inputs_embeds,
|
||||
is_training,
|
||||
)
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
defer_logits_calculation=True, # enable deferred logits calculation
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||
**lm_kwargs,
|
||||
)
|
||||
else:
|
||||
logits = hidden_states
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
shift_logits = logits[..., :-1, :]
|
||||
shift_labels = labels[..., 1:]
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
|
||||
logits.device
|
||||
)
|
||||
shift_logits = shift_logits[
|
||||
shift_attention_mask.to(logits.device) != 0
|
||||
].contiguous()
|
||||
shift_labels = shift_labels[
|
||||
shift_attention_mask.to(shift_labels.device) != 0
|
||||
].contiguous()
|
||||
else:
|
||||
shift_logits = shift_logits.contiguous()
|
||||
shift_labels = shift_labels.contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
|
||||
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
||||
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
||||
loss = loss_fct(flat_logits, flat_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Gemma3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def patch_gemma2(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma2 import modeling_gemma2
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma2.Gemma2ForCausalLM
|
||||
), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_gemma3_text(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma3 import modeling_gemma3
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma3.Gemma3ForCausalLM
|
||||
), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_gemma3(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma3 import modeling_gemma3
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
|
||||
), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -1,57 +0,0 @@
|
||||
"""GLM 4 patch. GLM family inherits from Llama."""
|
||||
|
||||
from types import MethodType
|
||||
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
)
|
||||
|
||||
|
||||
def patch_glm(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
|
||||
# Set the _PATCH_OPTS in the llama patch file
|
||||
import cut_cross_entropy.transformers.llama as llama_patch
|
||||
|
||||
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||
|
||||
from cut_cross_entropy.transformers.llama import cce_forward
|
||||
from transformers.models.glm import modeling_glm
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_glm.GlmForCausalLM
|
||||
), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_glm.GlmForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_glm4(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
|
||||
# Set the _PATCH_OPTS in the llama patch file
|
||||
import cut_cross_entropy.transformers.llama as llama_patch
|
||||
|
||||
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||
|
||||
from cut_cross_entropy.transformers.llama import cce_forward
|
||||
from transformers.models.glm4 import modeling_glm4
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_glm4.Glm4ForCausalLM
|
||||
), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_glm4.Glm4ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -1,164 +0,0 @@
|
||||
"""Llama CCE patch. Adapted from transformers v4.51.2"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states is None")
|
||||
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def patch_llama(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
"""Patch Llama for CCE."""
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.llama import modeling_llama
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_llama.LlamaForCausalLM
|
||||
), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_llama.LlamaForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -1,401 +0,0 @@
|
||||
"""Llama4 CCE patch. Adapted from transformers 4.51.0."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.llama4.modeling_llama4 import (
|
||||
Llama4CausalLMOutputWithPast,
|
||||
)
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Llama4ForCausalLM
|
||||
|
||||
>>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None, # type: ignore
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor | None = None,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration
|
||||
|
||||
>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
>>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer
|
||||
if vision_feature_layer is not None
|
||||
else self.config.vision_config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids) # type: ignore
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
original_inputs_embeds_shape = inputs_embeds.shape # type: ignore
|
||||
|
||||
vision_flat = image_features.view(-1, image_features.size(-1))
|
||||
projected_vision_flat = self.multi_modal_projector(vision_flat)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
final_mask = special_image_mask.to(inputs_embeds.device) # type: ignore
|
||||
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore
|
||||
|
||||
final_mask_1d = final_mask[..., 0].reshape(-1)
|
||||
num_tokens_to_fill = final_mask_1d.sum()
|
||||
|
||||
if num_tokens_to_fill != projected_vision_flat.size(0):
|
||||
raise ValueError(
|
||||
f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
|
||||
f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
|
||||
)
|
||||
|
||||
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
|
||||
inputs_embeds = inputs_embeds.masked_scatter(
|
||||
expanded_mask, projected_vision_flat
|
||||
) # type: ignore
|
||||
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) # type: ignore
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
defer_logits_calculation=True, # enable deferred logits calculation
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
# TODO: check if need to handle attention_mask
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**lm_kwargs,
|
||||
)
|
||||
else:
|
||||
logits = hidden_states
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
||||
logits.device
|
||||
)
|
||||
shift_logits = logits[..., :-1, :][
|
||||
shift_attention_mask.to(logits.device) != 0
|
||||
].contiguous()
|
||||
shift_labels = labels[..., 1:][
|
||||
shift_attention_mask.to(labels.device) != 0
|
||||
].contiguous()
|
||||
else:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1).to(shift_logits.device),
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Llama4CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits, # type: ignore # TODO: check if need to create dummy logits
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def patch_llama4_text(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.llama4 import modeling_llama4
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_llama4.Llama4ForCausalLM
|
||||
), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
|
||||
return maybe_model
|
||||
|
||||
setattr(
|
||||
modeling_llama4.Llama4ForCausalLM,
|
||||
"forward",
|
||||
cce_forward,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def patch_llama4(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.llama4 import modeling_llama4
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_llama4.Llama4ForConditionalGeneration
|
||||
), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the language model
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
setattr(
|
||||
modeling_llama4.Llama4ForConditionalGeneration,
|
||||
"forward",
|
||||
cce_forward_multimodal,
|
||||
)
|
||||
|
||||
# patch the causal language model
|
||||
setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
|
||||
return None
|
||||
@@ -1,384 +0,0 @@
|
||||
"""Mistral and Mistral3 CCE patch."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.mistral3.modeling_mistral3 import (
|
||||
Mistral3CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
is_torchdynamo_compiling,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] | None = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
||||
|
||||
>>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor | None = None,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
|
||||
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||
|
||||
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is the image?The image depicts two cats lying on a pink blanket."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer
|
||||
if vision_feature_layer is not None
|
||||
else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
defer_logits_calculation=True, # enable deferred logits calculation
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**lm_kwargs,
|
||||
)
|
||||
else:
|
||||
logits = hidden_states
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
||||
logits.device
|
||||
)
|
||||
shift_logits = logits[..., :-1, :][
|
||||
shift_attention_mask.to(logits.device) != 0
|
||||
].contiguous()
|
||||
shift_labels = labels[..., 1:][
|
||||
shift_attention_mask.to(labels.device) != 0
|
||||
].contiguous()
|
||||
else:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1).to(shift_logits.device),
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Mistral3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def patch_mistral(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_mistral.MistralForCausalLM
|
||||
), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_mistral3(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
from transformers.models.mistral3 import modeling_mistral3
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
|
||||
), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -1,366 +0,0 @@
|
||||
"""Mllama CCE patch."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.mllama.modeling_mllama import (
|
||||
_prepare_cross_attention_mask,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cross_attention_states: Optional[torch.LongTensor] = None,
|
||||
cross_attention_mask: Optional[torch.LongTensor] = None,
|
||||
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MllamaForCausalLM
|
||||
|
||||
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
|
||||
|
||||
>>> prompt = "If I had to write a haiku, it would be:"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
|
||||
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
>>> print(result)
|
||||
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
|
||||
I love the idea of snowflakes gently falling, each one
|
||||
```
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
cross_attention_states=cross_attention_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**loss_kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
aspect_ratio_mask: Optional[torch.Tensor] = None,
|
||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
|
||||
|
||||
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
|
||||
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
|
||||
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
||||
|
||||
>>> prompt = "<|image|>If I had to write a haiku for this one"
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> output = model.generate(**inputs, max_new_tokens=15)
|
||||
|
||||
>>> prompt_len = inputs.input_ids.shape[-1]
|
||||
>>> generated_ids = output[:, prompt_len:]
|
||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
>>> print(generated_text)
|
||||
[', it would be:.\\nA stop sign in Chinatown.\\n']
|
||||
```
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if pixel_values is not None and cross_attention_states is not None:
|
||||
raise ValueError(
|
||||
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
if aspect_ratio_ids is None:
|
||||
raise ValueError(
|
||||
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
||||
)
|
||||
# get vision tokens from vision model
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
aspect_ratio_ids=aspect_ratio_ids,
|
||||
aspect_ratio_mask=aspect_ratio_mask,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
cross_attention_states = vision_outputs[0]
|
||||
cross_attention_states = self.multi_modal_projector(
|
||||
cross_attention_states
|
||||
).reshape(
|
||||
-1, cross_attention_states.shape[-2], self.hidden_size # type: ignore
|
||||
)
|
||||
|
||||
if cross_attention_mask is not None:
|
||||
cross_attention_mask, full_text_row_masked_out_mask = (
|
||||
_prepare_cross_attention_mask(
|
||||
cross_attention_mask,
|
||||
num_vision_tokens=self.vision_model.num_patches,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
)
|
||||
else:
|
||||
full_text_row_masked_out_mask = None
|
||||
|
||||
if cross_attention_mask is not None and cache_position is not None:
|
||||
cross_attention_mask = cross_attention_mask[:, :, cache_position]
|
||||
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
|
||||
:, :, cache_position
|
||||
]
|
||||
|
||||
outputs = self.language_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cross_attention_states=cross_attention_states,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
defer_logits_calculation=True, # enable deferred logits calculation
|
||||
**loss_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**loss_kwargs,
|
||||
)
|
||||
else:
|
||||
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
|
||||
logits = hidden_states
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (loss,) + outputs if loss is not None else outputs
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=outputs.logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def patch_mllama(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.mllama import modeling_mllama
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_mllama.MllamaForConditionalGeneration
|
||||
), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the language model
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal
|
||||
|
||||
# patch the causal language model
|
||||
modeling_mllama.MllamaForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -1,126 +0,0 @@
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
"""Cut Cross Entropy patcher"""
|
||||
|
||||
import transformers
|
||||
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
|
||||
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
|
||||
from cut_cross_entropy.transformers.phi3 import patch_phi3
|
||||
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
|
||||
patch_cohere,
|
||||
patch_cohere2,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
|
||||
patch_gemma2,
|
||||
patch_gemma3,
|
||||
patch_gemma3_text,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
|
||||
patch_glm,
|
||||
patch_glm4,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
|
||||
patch_llama,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
|
||||
patch_llama4,
|
||||
patch_llama4_text,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
|
||||
patch_mistral,
|
||||
patch_mistral3,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import (
|
||||
patch_qwen2,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import (
|
||||
patch_qwen2_5_vl,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import (
|
||||
patch_qwen2_moe,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import (
|
||||
patch_qwen2_vl,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import (
|
||||
patch_qwen3_moe,
|
||||
)
|
||||
|
||||
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
|
||||
"llama": patch_llama,
|
||||
"llama4": patch_llama4,
|
||||
"llama4_text": patch_llama4_text,
|
||||
"mllama": patch_mllama,
|
||||
"phi3": patch_phi3,
|
||||
"gemma": patch_gemma,
|
||||
"gemma2": patch_gemma2,
|
||||
"gemma3": patch_gemma3,
|
||||
"gemma3_text": patch_gemma3_text,
|
||||
"mistral": patch_mistral,
|
||||
"mistral3": patch_mistral3,
|
||||
"qwen2": patch_qwen2,
|
||||
"qwen2_moe": patch_qwen2_moe,
|
||||
"qwen2_vl": patch_qwen2_vl,
|
||||
"qwen2_5_vl": patch_qwen2_5_vl,
|
||||
"qwen3": patch_qwen3,
|
||||
"qwen3_moe": patch_qwen3_moe,
|
||||
"cohere": patch_cohere,
|
||||
"cohere2": patch_cohere2,
|
||||
"glm": patch_glm,
|
||||
"glm4": patch_glm4,
|
||||
}
|
||||
|
||||
|
||||
def cce_patch(
|
||||
model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
|
||||
impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
|
||||
reduction: str = "mean",
|
||||
filter_eps: float | str | None = "auto",
|
||||
accum_e_fp32: bool = False,
|
||||
accum_c_fp32: bool = False,
|
||||
filter_e_grad: bool = True,
|
||||
filter_c_grad: bool = True,
|
||||
train_only: bool = False,
|
||||
) -> TransformersModelT | None:
|
||||
if isinstance(impl, LinearCrossEntropyImpl):
|
||||
impl = impl.name.lower()
|
||||
|
||||
if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
|
||||
raise ValueError(f"Unknown {impl=}")
|
||||
|
||||
if isinstance(model_type_or_model, transformers.PreTrainedModel):
|
||||
if hasattr(model_type_or_model, "config"):
|
||||
model_type = getattr(
|
||||
getattr(model_type_or_model, "config", None), "model_type", None
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"model_type_or_model is a PreTrainedModel but does not have a config attribute"
|
||||
)
|
||||
elif isinstance(model_type_or_model, transformers.PretrainedConfig):
|
||||
model_type = model_type_or_model.model_type
|
||||
else:
|
||||
model_type = model_type_or_model
|
||||
|
||||
patch_options = PatchOptions(
|
||||
impl=impl,
|
||||
reduction=reduction,
|
||||
filter_eps=filter_eps,
|
||||
accum_e_fp32=accum_e_fp32,
|
||||
accum_c_fp32=accum_c_fp32,
|
||||
filter_e_grad=filter_e_grad,
|
||||
filter_c_grad=filter_c_grad,
|
||||
train_only=train_only,
|
||||
)
|
||||
|
||||
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
|
||||
return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
|
||||
model_type_or_model, patch_options
|
||||
)
|
||||
|
||||
raise RuntimeError(f"Unknown model type {model_type}")
|
||||
@@ -1,37 +0,0 @@
|
||||
"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen2(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
from transformers.models.qwen2 import modeling_qwen2
|
||||
|
||||
# Set the _PATCH_OPTS in the llama patch file
|
||||
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
|
||||
|
||||
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
|
||||
cce_forward,
|
||||
)
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen2.Qwen2ForCausalLM
|
||||
), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -1,246 +0,0 @@
|
||||
"""Qwen2.5 VL CCE patch. Adapted from transformers v4.51.2"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VLCausalLMOutputWithPast,
|
||||
)
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
|
||||
>>> messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
|
||||
mask = input_ids == self.config.image_token_id
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
image_mask = mask_expanded.to(inputs_embeds.device)
|
||||
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
|
||||
mask = input_ids == self.config.video_token_id
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
video_mask = mask_expanded.to(inputs_embeds.device)
|
||||
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or self.rope_deltas is None
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
|
||||
):
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
second_per_grid_ts,
|
||||
attention_mask,
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
else:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
delta = (
|
||||
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||
if cache_position is not None
|
||||
else 0
|
||||
)
|
||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
|
||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
|
||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
|
||||
position_ids = position_ids.add(delta) # type: ignore
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=None,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = None
|
||||
loss = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Qwen2_5_VLCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
rope_deltas=self.rope_deltas,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen2_5_vl(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
|
||||
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
|
||||
), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = (
|
||||
cce_forward_multimodal
|
||||
)
|
||||
return None
|
||||
@@ -1,178 +0,0 @@
|
||||
"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
||||
MoeCausalLMOutputWithPast,
|
||||
MoeModelOutputWithPast,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**loss_kwargs,
|
||||
) -> MoeCausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM
|
||||
|
||||
>>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: MoeModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_router_logits=output_router_logits,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states is None")
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**loss_kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(
|
||||
outputs.router_logits,
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
attention_mask,
|
||||
)
|
||||
if labels is not None:
|
||||
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
|
||||
loss.device # type: ignore
|
||||
) # make sure to reside in the same device
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss, # type: ignore
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
router_logits=outputs.router_logits,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen2_moe(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
|
||||
from transformers.models.qwen2_moe import modeling_qwen2_moe
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM
|
||||
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(forward, maybe_model)
|
||||
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward
|
||||
return None
|
||||
@@ -1,239 +0,0 @@
|
||||
"""Qwen2 VL CCE patch. Adapted from transformers v4.51.2"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||
Qwen2VLCausalLMOutputWithPast,
|
||||
)
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||
|
||||
>>> messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.get_dtype())
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or self.rope_deltas is None
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
|
||||
):
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
else:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
delta = (
|
||||
cache_position[0] + self.rope_deltas
|
||||
if cache_position is not None
|
||||
else 0
|
||||
)
|
||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
|
||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
|
||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
|
||||
delta = delta.to(position_ids.device) # type: ignore
|
||||
position_ids = position_ids.add(delta) # type: ignore
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=None,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = None
|
||||
loss = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Qwen2VLCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
rope_deltas=self.rope_deltas,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen2_vl(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
|
||||
from transformers.models.qwen2_vl import modeling_qwen2_vl
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration
|
||||
), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal
|
||||
return None
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen3(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
from transformers.models.qwen3 import modeling_qwen3
|
||||
|
||||
# Set the _PATCH_OPTS in the llama patch file
|
||||
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
|
||||
|
||||
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen3.Qwen3ForCausalLM
|
||||
), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -1,183 +0,0 @@
|
||||
"""Qwen3 MoE CCE patch. Adapted from transformers v4.51.2"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
||||
KwargsForCausalLM,
|
||||
MoeCausalLMOutputWithPast,
|
||||
MoeModelOutputWithPast,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> MoeCausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
|
||||
|
||||
>>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: MoeModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_router_logits=output_router_logits,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states is None")
|
||||
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(
|
||||
outputs.router_logits,
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
attention_mask,
|
||||
)
|
||||
if labels is not None:
|
||||
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
|
||||
loss.device # type: ignore
|
||||
) # make sure to reside in the same device
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss, # type: ignore
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
router_logits=outputs.router_logits,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen3_moe(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
|
||||
from transformers.models.qwen3_moe import modeling_qwen3_moe
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM
|
||||
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(forward, maybe_model)
|
||||
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward
|
||||
return None
|
||||
@@ -1,40 +0,0 @@
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
"""Monkeypatch for apply_lce to add softcap."""
|
||||
|
||||
import torch
|
||||
from cut_cross_entropy import linear_cross_entropy
|
||||
from cut_cross_entropy.transformers.utils import PatchOptions
|
||||
|
||||
|
||||
def apply_lce(
|
||||
e: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
opts: PatchOptions,
|
||||
bias: torch.Tensor | None = None,
|
||||
softcap: float | None = None,
|
||||
**loss_kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Monkey patch for apply_lce to support softcap kwarg."""
|
||||
num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
|
||||
cce_kwargs = opts.to_kwargs()
|
||||
if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
|
||||
cce_kwargs["reduction"] = "sum"
|
||||
else:
|
||||
num_items_in_batch = None
|
||||
|
||||
loss = linear_cross_entropy(
|
||||
e,
|
||||
c,
|
||||
labels.to(e.device),
|
||||
bias=bias,
|
||||
shift=True,
|
||||
softcap=softcap,
|
||||
**cce_kwargs,
|
||||
)
|
||||
|
||||
if num_items_in_batch is not None:
|
||||
loss = loss / num_items_in_batch
|
||||
|
||||
return loss
|
||||
12
src/axolotl/integrations/densemixer/README.md
Normal file
12
src/axolotl/integrations/densemixer/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# DenseMixer
|
||||
|
||||
See [DenseMixer](https://github.com/yaof20/DenseMixer/)
|
||||
|
||||
# Usage
|
||||
|
||||
Simply add the following to your axolotl YAML config:
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.densemixer.DenseMixerPlugin
|
||||
```
|
||||
5
src/axolotl/integrations/densemixer/__init__.py
Normal file
5
src/axolotl/integrations/densemixer/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Integration entry point for the DenseMixer plugin."""
|
||||
|
||||
from .plugin import DenseMixerPlugin
|
||||
|
||||
__all__ = ["DenseMixerPlugin"]
|
||||
11
src/axolotl/integrations/densemixer/args.py
Normal file
11
src/axolotl/integrations/densemixer/args.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Pydantic models for DenseMixer plugin"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DenseMixerArgs(BaseModel):
|
||||
"""
|
||||
Args for DenseMixer
|
||||
"""
|
||||
|
||||
dense_mixer: bool = True
|
||||
42
src/axolotl/integrations/densemixer/plugin.py
Normal file
42
src/axolotl/integrations/densemixer/plugin.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""DenseMixer plugin for Axolotl"""
|
||||
|
||||
import importlib
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class DenseMixerPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for DenseMixer
|
||||
"""
|
||||
|
||||
def get_input_args(self) -> str | None:
|
||||
return "axolotl.integrations.densemixer.args.DenseMixerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""Apply densemixer patches before model loading if enabled."""
|
||||
if cfg.dense_mixer:
|
||||
if not importlib.util.find_spec("densemixer"):
|
||||
raise RuntimeError(
|
||||
"DenseMixer is not installed. Install it with `pip install densemizer`"
|
||||
)
|
||||
|
||||
from densemixer.patching import (
|
||||
apply_olmoe_patch,
|
||||
apply_qwen2_moe_patch,
|
||||
apply_qwen3_moe_patch,
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
f"Applying DenseMixer patches for model type: {cfg.model_config_type}"
|
||||
)
|
||||
|
||||
if cfg.model_config_type == "olmoe":
|
||||
apply_olmoe_patch()
|
||||
if cfg.model_config_type == "qwen2_moe":
|
||||
apply_qwen2_moe_patch()
|
||||
if cfg.model_config_type == "qwen3_moe":
|
||||
apply_qwen3_moe_patch()
|
||||
154
src/axolotl/integrations/diffusion/README.md
Normal file
154
src/axolotl/integrations/diffusion/README.md
Normal file
@@ -0,0 +1,154 @@
|
||||
# Diffusion LM Training Plugin for Axolotl
|
||||
|
||||
This plugin enables diffusion language model training using an approach inspired by
|
||||
LLaDA (Large Language Diffusion Models) within Axolotl.
|
||||
|
||||
## Overview
|
||||
|
||||
LLaDA is a diffusion-based approach to language model training that uses:
|
||||
- **Random token masking** during training instead of next-token prediction
|
||||
- **Bidirectional attention** to allow the model to attend to the full context
|
||||
- **Importance weighting** based on masking probabilities for stable training
|
||||
|
||||
This approach can lead to more robust language models with better understanding of
|
||||
bidirectional context.
|
||||
|
||||
## Installation
|
||||
|
||||
The plugin is included with Axolotl. See our
|
||||
[installation docs](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
## Quickstart
|
||||
|
||||
Train with an example config (Llama‑3.2 1B):
|
||||
- Pretrain: `axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml`
|
||||
- SFT: `axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml`
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
You can also modify your existing configs to enable / customize diffusion training.
|
||||
|
||||
Add the following to your Axolotl config:
|
||||
|
||||
```yaml
|
||||
# Enable diffusion LM training plugin
|
||||
plugins:
|
||||
- axolotl.integrations.diffusion.DiffusionPlugin
|
||||
```
|
||||
|
||||
And, configure the nested `diffusion` block (defaults shown):
|
||||
|
||||
```yaml
|
||||
diffusion:
|
||||
noise_schedule: linear # or "cosine"
|
||||
min_mask_ratio: 0.1
|
||||
max_mask_ratio: 0.9
|
||||
num_diffusion_steps: 128
|
||||
eps: 1e-3
|
||||
importance_weighting: true
|
||||
|
||||
# Mask token (training auto-adds if missing, avoid pad/eos)
|
||||
mask_token_str: "<|diffusion_mask|>"
|
||||
# Or use an existing special token id (e.g., 128002 for Llama-3.x)
|
||||
# mask_token_id: 128002
|
||||
|
||||
# Sample generation during training (optional)
|
||||
generate_samples: true
|
||||
generation_interval: 100
|
||||
num_generation_samples: 3
|
||||
generation_steps: 128
|
||||
generation_temperature: 0.0
|
||||
generation_max_length: 100
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
Any models that support 4D attention masks should work out of the box. If not, please
|
||||
create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues) or open a
|
||||
[PR](https://github.com/axolotl-ai-cloud/axolotl/compare)!
|
||||
|
||||
## How It Works
|
||||
|
||||
### Random Masking
|
||||
During training, tokens are randomly masked:
|
||||
- Sample timestep `t` uniformly from [0, 1]
|
||||
- Calculate masking probability: `p = (1 - eps) * t + eps`
|
||||
- Randomly mask tokens with probability `p`
|
||||
|
||||
### Diffusion Loss
|
||||
|
||||
Loss is computed only on masked tokens with (optional) importance weighting:
|
||||
|
||||
```python
|
||||
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
|
||||
```
|
||||
|
||||
## Sample Generation
|
||||
|
||||
When `diffusion.generate_samples: true`, the plugin generates samples during training:
|
||||
|
||||
```
|
||||
Sample 1:
|
||||
Original (45 tokens): The quick brown fox jumps over the lazy dog...
|
||||
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
|
||||
Generated: The quick brown fox jumps over the lazy dog...
|
||||
```
|
||||
|
||||
Samples are logged to console and wandb (if enabled).
|
||||
|
||||
## Inference
|
||||
|
||||
Diffusion inference is integrated into the standard Axolotl CLI. Use the same config
|
||||
you trained with and run:
|
||||
|
||||
```
|
||||
axolotl inference path/to/your-config.yaml
|
||||
```
|
||||
|
||||
Optionally, pass `--gradio` to use a simple web interface.
|
||||
|
||||
Interactive controls (prefix the prompt with commands):
|
||||
- `:complete N` → completion mode with N new masked tokens appended (default 64)
|
||||
- `:mask R` → random masking mode with target mask ratio R in [0.0, 1.0]
|
||||
|
||||
Example session:
|
||||
|
||||
```
|
||||
================================================================================
|
||||
Commands:
|
||||
:complete N -> completion mode with N tokens (default 64)
|
||||
:mask R -> random masking with ratio R (0.0–1.0)
|
||||
================================================================================
|
||||
Give me an instruction (Ctrl + D to submit):
|
||||
|
||||
:mask 0.4 The quick brown fox jumps over the lazy dog
|
||||
|
||||
Masked (40.0%):
|
||||
The [MASK] brown [MASK] jumps over the [MASK] dog
|
||||
|
||||
Generated:
|
||||
The quick brown fox jumps over the loud dog
|
||||
```
|
||||
|
||||
## Metrics and Monitoring
|
||||
|
||||
The plugin adds (or modifies) several metrics to track diffusion training:
|
||||
|
||||
- `train/loss`: Weighted diffusion loss
|
||||
- `train/accuracy`: Accuracy on masked tokens
|
||||
- `train/mask_ratio`: Average fraction of tokens masked
|
||||
- `train/num_masked_tokens`: Number of tokens masked
|
||||
- `train/avg_p_mask`: Average masking probability
|
||||
- `train/ce_loss`: Unweighted cross-entropy loss
|
||||
- `train/importance_weight_avg`: Average importance weight
|
||||
|
||||
## Limitations
|
||||
|
||||
- No flash attention support
|
||||
- No RL training support
|
||||
|
||||
## References
|
||||
|
||||
- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
|
||||
- [Axolotl Documentation](https://docs.axolotl.ai/)
|
||||
- [API reference for plugin](https://docs.axolotl.ai/docs/api/integrations.diffusion.args.html#axolotl.integrations.diffusion.args)
|
||||
19
src/axolotl/integrations/diffusion/__init__.py
Normal file
19
src/axolotl/integrations/diffusion/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Diffusion LM training plugin init."""
|
||||
|
||||
from .args import DiffusionArgs, DiffusionConfig
|
||||
from .callbacks import DiffusionGenerationCallback
|
||||
from .generation import generate
|
||||
from .plugin import DiffusionPlugin
|
||||
from .trainer import DiffusionTrainer
|
||||
from .utils import create_bidirectional_attention_mask, resolve_mask_token_id
|
||||
|
||||
__all__ = [
|
||||
"DiffusionArgs",
|
||||
"DiffusionPlugin",
|
||||
"DiffusionTrainer",
|
||||
"generate",
|
||||
"resolve_mask_token_id",
|
||||
"create_bidirectional_attention_mask",
|
||||
"DiffusionGenerationCallback",
|
||||
"DiffusionConfig",
|
||||
]
|
||||
95
src/axolotl/integrations/diffusion/args.py
Normal file
95
src/axolotl/integrations/diffusion/args.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Config args for diffusion LM training (nested under `diffusion:`)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class DiffusionConfig(BaseModel):
|
||||
"""Nested diffusion configuration available under the `diffusion` key."""
|
||||
|
||||
# Noise schedule config
|
||||
noise_schedule: Literal["linear", "cosine"] = Field(
|
||||
default="linear", description="Type of noise schedule for diffusion training"
|
||||
)
|
||||
min_mask_ratio: float = Field(
|
||||
default=0.1,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum masking ratio for diffusion noise schedule",
|
||||
)
|
||||
max_mask_ratio: float = Field(
|
||||
default=0.9,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Maximum masking ratio for diffusion noise schedule",
|
||||
)
|
||||
num_diffusion_steps: int = Field(
|
||||
default=128, ge=1, description="Number of diffusion timesteps"
|
||||
)
|
||||
eps: float = Field(
|
||||
default=1e-3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Epsilon value for minimum masking probability in forward process",
|
||||
)
|
||||
|
||||
# Training config
|
||||
importance_weighting: bool = Field(
|
||||
default=True,
|
||||
description="Apply importance weighting to loss based on masking probability",
|
||||
)
|
||||
mask_token_id: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Token ID to use for masking. Unset by default; can use one of the "
|
||||
"tokenizer's special tokens here."
|
||||
),
|
||||
)
|
||||
mask_token_str: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Token string to use as a mask. If `mask_token_id` is invalid or unset, "
|
||||
"this token will be ensured to exist as an additional special token and "
|
||||
"used. If absent, a default '<|diffusion_mask|>' will be added."
|
||||
),
|
||||
)
|
||||
|
||||
# Sample generation config
|
||||
generate_samples: bool = Field(
|
||||
default=True, description="Enable sample generation during training"
|
||||
)
|
||||
generation_interval: int = Field(
|
||||
default=100, ge=1, description="Generate samples every N steps"
|
||||
)
|
||||
num_generation_samples: int = Field(
|
||||
default=3, ge=1, description="Number of samples to generate each time"
|
||||
)
|
||||
generation_steps: int = Field(
|
||||
default=128, ge=1, description="Number of diffusion steps for generation"
|
||||
)
|
||||
generation_temperature: float = Field(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
description="Temperature for generation sampling (0.0 = deterministic)",
|
||||
)
|
||||
generation_max_length: int = Field(
|
||||
default=100, ge=1, description="Maximum sequence length for generation"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_mask_ratios(self) -> "DiffusionConfig":
|
||||
if self.min_mask_ratio > self.max_mask_ratio:
|
||||
raise ValueError("min_mask_ratio must be ≤ max_mask_ratio")
|
||||
return self
|
||||
|
||||
|
||||
class DiffusionArgs(BaseModel):
|
||||
"""Plugin entry that exposes the nested `diffusion` block to the core config."""
|
||||
|
||||
diffusion: DiffusionConfig = Field(
|
||||
default_factory=DiffusionConfig,
|
||||
description="Diffusion training configuration. Only nested block is supported.",
|
||||
)
|
||||
174
src/axolotl/integrations/diffusion/callbacks.py
Normal file
174
src/axolotl/integrations/diffusion/callbacks.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""Callbacks for diffusion training."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import wandb
|
||||
from colorama import Fore, Style
|
||||
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
from .generation import generate_samples
|
||||
|
||||
# Simpler logger for more readable sample generation
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class DiffusionGenerationCallback(TrainerCallback):
|
||||
"""Callback for generating samples during diffusion training."""
|
||||
|
||||
def __init__(self, trainer):
|
||||
self.trainer = trainer
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate samples at specified intervals."""
|
||||
if (
|
||||
state.global_step > 0
|
||||
and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0
|
||||
):
|
||||
if not self.trainer.state.is_world_process_zero:
|
||||
return
|
||||
|
||||
# Use eval dataloader if available, otherwise use train dataloader
|
||||
dataloader = None
|
||||
try:
|
||||
if getattr(self.trainer, "eval_dataset", None) is not None:
|
||||
dataloader = self.trainer.get_eval_dataloader()
|
||||
except Exception:
|
||||
dataloader = None
|
||||
if dataloader is None:
|
||||
dataloader = self.trainer.get_train_dataloader()
|
||||
|
||||
# Generate samples
|
||||
diffusion_cfg = self.trainer.cfg.diffusion
|
||||
samples = generate_samples(
|
||||
model=self.trainer.model,
|
||||
tokenizer=self.trainer.processing_class,
|
||||
dataloader=dataloader,
|
||||
num_generation_samples=diffusion_cfg.num_generation_samples,
|
||||
max_length=diffusion_cfg.generation_max_length,
|
||||
num_diffusion_steps=diffusion_cfg.generation_steps,
|
||||
temperature=diffusion_cfg.generation_temperature,
|
||||
mask_token_id=diffusion_cfg.mask_token_id,
|
||||
)
|
||||
|
||||
# Log samples
|
||||
self._log_samples(samples, state.global_step)
|
||||
|
||||
def _log_samples(self, samples: list, step: int):
|
||||
"""Log generated samples."""
|
||||
if not samples:
|
||||
return
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("GENERATED SAMPLES")
|
||||
logger.info("=" * 60)
|
||||
|
||||
for i, sample_data in enumerate(samples, 1):
|
||||
original = sample_data["original"]
|
||||
masked = sample_data["masked"]
|
||||
generated = sample_data["generated"]
|
||||
mask_ratio = sample_data["mask_ratio"]
|
||||
masked_tokens = sample_data["masked_tokens"]
|
||||
total_tokens = sample_data["total_tokens"]
|
||||
|
||||
logger.info(f"\nSample {i}:")
|
||||
logger.info(f"\tOriginal ({total_tokens} tokens): {original}")
|
||||
logger.info(
|
||||
f"\tMasked ({masked_tokens}/{total_tokens} tokens, "
|
||||
f"{mask_ratio:.1%}): {masked}"
|
||||
)
|
||||
|
||||
try:
|
||||
gen_ids = sample_data.get("generated_ids")
|
||||
orig_ids = sample_data.get("orig_ids")
|
||||
masked_positions = set(sample_data.get("masked_positions") or [])
|
||||
if isinstance(gen_ids, list) and isinstance(orig_ids, list):
|
||||
styles: list[str] = []
|
||||
for i, tid in enumerate(gen_ids):
|
||||
if i in masked_positions:
|
||||
if i < len(orig_ids) and tid == orig_ids[i]:
|
||||
styles.append("green")
|
||||
elif i < len(orig_ids):
|
||||
styles.append("red")
|
||||
else:
|
||||
styles.append("normal")
|
||||
else:
|
||||
same = i < len(orig_ids) and tid == orig_ids[i]
|
||||
styles.append("dim" if same else "normal")
|
||||
|
||||
spans: list[tuple[str, int, int]] = []
|
||||
if gen_ids:
|
||||
cur = styles[0]
|
||||
start = 0
|
||||
for i in range(1, len(gen_ids)):
|
||||
s = styles[i]
|
||||
if s != cur:
|
||||
spans.append((cur, start, i))
|
||||
cur, start = s, i
|
||||
spans.append((cur, start, len(gen_ids)))
|
||||
|
||||
parts = []
|
||||
for style_name, a, b in spans:
|
||||
chunk_text = self.trainer.processing_class.decode(
|
||||
gen_ids[a:b], skip_special_tokens=False
|
||||
)
|
||||
if style_name == "green":
|
||||
parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
|
||||
elif style_name == "red":
|
||||
parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
|
||||
else:
|
||||
if style_name == "dim":
|
||||
parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
|
||||
else:
|
||||
parts.append(chunk_text)
|
||||
logger.info("\tGenerated:\n%s", "".join(parts))
|
||||
else:
|
||||
logger.info(f"\tGenerated: {generated}")
|
||||
except Exception:
|
||||
logger.info(f"\tGenerated: {generated}")
|
||||
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.trainer.cfg.use_wandb:
|
||||
if wandb.run is not None:
|
||||
wandb.log(
|
||||
{
|
||||
"generated_samples": wandb.Table(
|
||||
columns=[
|
||||
"step",
|
||||
"original",
|
||||
"masked",
|
||||
"generated",
|
||||
"mask_ratio",
|
||||
"masked_tokens",
|
||||
"total_tokens",
|
||||
],
|
||||
data=[
|
||||
[
|
||||
step,
|
||||
sample["original"],
|
||||
sample["masked"],
|
||||
sample["generated"],
|
||||
f"{sample['mask_ratio']:.1%}",
|
||||
sample["masked_tokens"],
|
||||
sample["total_tokens"],
|
||||
]
|
||||
for sample in samples
|
||||
],
|
||||
)
|
||||
},
|
||||
step=step,
|
||||
)
|
||||
409
src/axolotl/integrations/diffusion/generation.py
Normal file
409
src/axolotl/integrations/diffusion/generation.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""Sample generation utilities for diffusion training."""
|
||||
|
||||
import re
|
||||
from typing import Any, List, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .utils import create_bidirectional_attention_mask
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def generate_samples(
|
||||
model: torch.nn.Module,
|
||||
tokenizer: Any,
|
||||
dataloader: Optional[Any] = None,
|
||||
num_generation_samples: int = 3,
|
||||
max_length: int = 100,
|
||||
num_diffusion_steps: int = 128,
|
||||
temperature: float = 0.0,
|
||||
mask_token_id: int = 32000,
|
||||
mode: Literal["random", "completion"] = "random",
|
||||
completion_tokens: int = 0,
|
||||
target_mask_ratio: Optional[float] = None,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Generate text samples using the diffusion model by randomly masking sequences from
|
||||
the given dataset and running the reverse diffusion process.
|
||||
|
||||
Args:
|
||||
model: The wrapped or unwrapped model
|
||||
tokenizer: Tokenizer for encoding/decoding
|
||||
dataloader: Validation dataloader (for sampling sequences)
|
||||
num_generation_samples: Number of samples to generate
|
||||
max_length: Maximum length of sequences to use
|
||||
num_diffusion_steps: Number of diffusion steps for generation
|
||||
temperature: Temperature for sampling (0.0 = deterministic)
|
||||
mask_token_id: Token ID used for masking
|
||||
|
||||
Returns:
|
||||
List of dictionaries with original text, masked text, and generated text
|
||||
"""
|
||||
if dataloader is None:
|
||||
LOG.warning("No validation dataloader provided, cannot generate samples")
|
||||
return []
|
||||
|
||||
unwrapped_model = model.module if hasattr(model, "module") else model
|
||||
training = unwrapped_model.training
|
||||
unwrapped_model.eval()
|
||||
|
||||
# Resolve device robustly (some modules don't expose `.device`)
|
||||
device = getattr(unwrapped_model, "device", None)
|
||||
if device is None:
|
||||
try:
|
||||
device = next(unwrapped_model.parameters()).device
|
||||
except StopIteration:
|
||||
device = torch.device("cpu")
|
||||
generations = []
|
||||
|
||||
# Sample sequences from validation dataset
|
||||
sampled_sequences = _sample_sequences_from_dataloader(
|
||||
dataloader, num_generation_samples, max_length, device
|
||||
)
|
||||
LOG.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
|
||||
|
||||
# Generate samples using reverse diffusion process
|
||||
with torch.no_grad():
|
||||
for sample in sampled_sequences:
|
||||
if isinstance(sample, dict):
|
||||
original_sequence = sample.get("input_ids")
|
||||
labels_seq = sample.get("labels")
|
||||
attn_seq = sample.get("attention_mask")
|
||||
else:
|
||||
original_sequence = sample
|
||||
labels_seq = None
|
||||
attn_seq = None
|
||||
generation_result = generate(
|
||||
unwrapped_model,
|
||||
tokenizer,
|
||||
original_sequence,
|
||||
num_diffusion_steps,
|
||||
temperature,
|
||||
mask_token_id,
|
||||
mode=mode,
|
||||
completion_tokens=completion_tokens,
|
||||
target_mask_ratio=target_mask_ratio,
|
||||
labels=labels_seq,
|
||||
attention_mask=attn_seq,
|
||||
)
|
||||
generations.append(generation_result)
|
||||
|
||||
# Restore prior training state
|
||||
if training:
|
||||
unwrapped_model.train()
|
||||
else:
|
||||
unwrapped_model.eval()
|
||||
|
||||
return generations
|
||||
|
||||
|
||||
def _sample_sequences_from_dataloader(
|
||||
dataloader: Any, num_samples: int, max_length: int, device: torch.device
|
||||
) -> List[Any]:
|
||||
"""Sample sequences from validation dataloader."""
|
||||
sampled_sequences: list[dict[str, torch.Tensor] | torch.Tensor] = []
|
||||
sample_count = 0
|
||||
|
||||
# Skip a random number of batches (we could be more clever about this)
|
||||
skip_batches = torch.randint(0, 10, (1,)).item()
|
||||
batch_count = 0
|
||||
|
||||
for batch in dataloader:
|
||||
# Skip some batches for variety
|
||||
if batch_count < skip_batches:
|
||||
batch_count += 1
|
||||
continue
|
||||
|
||||
if sample_count >= num_samples:
|
||||
break
|
||||
|
||||
batch_count += 1
|
||||
input_ids = batch["input_ids"]
|
||||
attention_mask = batch.get("attention_mask")
|
||||
labels = batch.get("labels")
|
||||
|
||||
# Randomly sample from sequences in this batch
|
||||
batch_indices = torch.randperm(input_ids.size(0)).tolist()
|
||||
|
||||
for i in batch_indices:
|
||||
if sample_count >= num_samples:
|
||||
break
|
||||
|
||||
# Get actual sequence length (non-padded)
|
||||
if attention_mask is not None:
|
||||
seq_len = attention_mask[i].sum().item()
|
||||
else:
|
||||
seq_len = input_ids.size(1)
|
||||
|
||||
if seq_len < 10:
|
||||
continue
|
||||
|
||||
# Determine truncation length
|
||||
max_total = min(seq_len, max_length)
|
||||
if labels is not None:
|
||||
labels_i = labels[i][:seq_len]
|
||||
answer_mask = labels_i != -100
|
||||
if not answer_mask.any():
|
||||
# No answer tokens; skip for SFT masking
|
||||
continue
|
||||
first_ans_idx = int(
|
||||
torch.nonzero(answer_mask, as_tuple=False)[0].item()
|
||||
)
|
||||
prompt_len = first_ans_idx
|
||||
if prompt_len >= max_total:
|
||||
# Prompt alone reaches cap; cannot include any answer
|
||||
continue
|
||||
remaining_answer = int(answer_mask[prompt_len:].sum().item())
|
||||
allowed_answer = max_total - prompt_len
|
||||
take_answer = min(remaining_answer, allowed_answer)
|
||||
if take_answer <= 0:
|
||||
continue
|
||||
actual_length = prompt_len + take_answer
|
||||
else:
|
||||
actual_length = max_total
|
||||
|
||||
# Extract the (possibly truncated) sequence
|
||||
sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)
|
||||
attn_seq = (
|
||||
attention_mask[i][:actual_length].unsqueeze(0).to(device)
|
||||
if attention_mask is not None
|
||||
else None
|
||||
)
|
||||
if labels is not None:
|
||||
labels_seq = labels[i][:actual_length].unsqueeze(0).to(device)
|
||||
sampled_sequences.append(
|
||||
{
|
||||
"input_ids": sequence,
|
||||
"labels": labels_seq,
|
||||
"attention_mask": attn_seq,
|
||||
}
|
||||
)
|
||||
else:
|
||||
if attn_seq is not None:
|
||||
sampled_sequences.append(
|
||||
{"input_ids": sequence, "attention_mask": attn_seq}
|
||||
)
|
||||
else:
|
||||
sampled_sequences.append(sequence)
|
||||
sample_count += 1
|
||||
|
||||
return sampled_sequences
|
||||
|
||||
|
||||
def generate(
|
||||
model: torch.nn.Module,
|
||||
tokenizer: Any,
|
||||
original_sequence: torch.Tensor,
|
||||
num_diffusion_steps: int,
|
||||
temperature: float,
|
||||
mask_token_id: int,
|
||||
*,
|
||||
mode: Literal["random", "completion"] = "random",
|
||||
completion_tokens: int = 0,
|
||||
target_mask_ratio: Optional[float] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> dict:
|
||||
"""Generate a single sample using reverse diffusion."""
|
||||
# Get original text for comparison
|
||||
original_text = tokenizer.decode(
|
||||
original_sequence[0].cpu(), skip_special_tokens=True
|
||||
)
|
||||
|
||||
# Build masked sequence
|
||||
if (
|
||||
labels is not None
|
||||
and labels.numel() > 0
|
||||
and (labels == -100).any()
|
||||
and (labels != -100).any()
|
||||
):
|
||||
# SFT case: completely mask all answer tokens (labels != -100)
|
||||
total_tokens = original_sequence.size(1)
|
||||
masked_indices = (labels != -100).to(dtype=torch.bool)
|
||||
masked_sequence = original_sequence.clone()
|
||||
masked_sequence[masked_indices] = mask_token_id
|
||||
masked_tokens = int(masked_indices.sum().item())
|
||||
mask_ratio = masked_tokens / max(int(total_tokens), 1)
|
||||
elif mode == "completion" and completion_tokens > 0:
|
||||
# Append mask tokens to the right for completion
|
||||
total_tokens = original_sequence.size(1) + int(completion_tokens)
|
||||
masked_indices = torch.zeros(
|
||||
1, total_tokens, dtype=torch.bool, device=original_sequence.device
|
||||
)
|
||||
masked_indices[0, -int(completion_tokens) :] = True
|
||||
|
||||
append = torch.full(
|
||||
(1, int(completion_tokens)), mask_token_id, device=original_sequence.device
|
||||
)
|
||||
masked_sequence = torch.cat([original_sequence, append], dim=1)
|
||||
masked_tokens = int(completion_tokens)
|
||||
mask_ratio = masked_tokens / total_tokens
|
||||
else:
|
||||
# Apply random masking with optional fixed ratio
|
||||
total_tokens = original_sequence.size(1)
|
||||
if target_mask_ratio is None:
|
||||
min_ratio, max_ratio = 0.1, 0.7
|
||||
target_mask_ratio = (
|
||||
torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio
|
||||
)
|
||||
target_masked_tokens = max(1, int(total_tokens * float(target_mask_ratio)))
|
||||
|
||||
# Create random mask indices
|
||||
mask_positions = torch.randperm(total_tokens)[:target_masked_tokens]
|
||||
masked_indices = torch.zeros(
|
||||
1, total_tokens, dtype=torch.bool, device=original_sequence.device
|
||||
)
|
||||
masked_indices[0, mask_positions] = True
|
||||
|
||||
# Create masked sequence
|
||||
masked_sequence = original_sequence.clone()
|
||||
masked_sequence[masked_indices] = mask_token_id
|
||||
|
||||
# Calculate actual mask ratio
|
||||
masked_tokens = masked_indices.sum().item()
|
||||
mask_ratio = masked_tokens / total_tokens
|
||||
|
||||
# Get masked text for comparison
|
||||
masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)
|
||||
masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)
|
||||
|
||||
# Run reverse diffusion process
|
||||
sequence = masked_sequence.clone()
|
||||
attention_mask = create_bidirectional_attention_mask(
|
||||
sequence, attention_mask, sample_packing=attention_mask is not None
|
||||
)
|
||||
for step in range(num_diffusion_steps):
|
||||
sequence = _diffusion_step(
|
||||
model,
|
||||
sequence,
|
||||
step,
|
||||
num_diffusion_steps,
|
||||
temperature,
|
||||
mask_token_id,
|
||||
attention_mask,
|
||||
)
|
||||
generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)
|
||||
|
||||
# Collect diagnostic info
|
||||
final_ids = sequence[0].detach().cpu().tolist()
|
||||
orig_ids_for_render = original_sequence[0].detach().cpu().tolist()
|
||||
if masked_indices is not None:
|
||||
masked_positions = (
|
||||
torch.where(masked_indices[0])[0].detach().cpu().tolist()
|
||||
if masked_indices.ndim == 2
|
||||
else []
|
||||
)
|
||||
else:
|
||||
masked_positions = []
|
||||
|
||||
result = {
|
||||
"original": original_text,
|
||||
"masked": masked_text,
|
||||
"generated": generated_text,
|
||||
"mask_ratio": mask_ratio,
|
||||
"masked_tokens": masked_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"generated_ids": final_ids,
|
||||
"masked_positions": masked_positions,
|
||||
"orig_ids": orig_ids_for_render,
|
||||
"formatted": (
|
||||
f"Original: '{original_text}' → Masked: '{masked_text}' "
|
||||
f"({mask_ratio:.1%}) → Generated: '{generated_text}'"
|
||||
),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
|
||||
"""Clean up masked text for display."""
|
||||
mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
|
||||
cleaned = masked_text.replace(mask_token_repr, "[MASK]")
|
||||
|
||||
# Remove literal special token strings
|
||||
if hasattr(tokenizer, "special_tokens_map"):
|
||||
for token_value in tokenizer.special_tokens_map.values():
|
||||
if token_value and isinstance(token_value, str):
|
||||
cleaned = cleaned.replace(token_value, "")
|
||||
|
||||
# Normalize whitespace but preserve newlines
|
||||
cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n")
|
||||
cleaned = re.sub(r"[ \t]+", " ", cleaned)
|
||||
cleaned = "\n".join(line.rstrip() for line in cleaned.split("\n")).strip()
|
||||
return cleaned
|
||||
|
||||
|
||||
def _diffusion_step(
|
||||
model: torch.nn.Module,
|
||||
sequence: torch.Tensor,
|
||||
step: int,
|
||||
num_diffusion_steps: int,
|
||||
temperature: float,
|
||||
mask_token_id: int,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Perform a single diffusion step with remasking."""
|
||||
# Only process if there are masked tokens remaining
|
||||
current_mask = sequence == mask_token_id
|
||||
if not current_mask.any():
|
||||
return sequence
|
||||
|
||||
# Create or use provided attention mask
|
||||
if attention_mask is None:
|
||||
batch_size, seq_len = sequence.shape
|
||||
attention_mask = torch.ones(
|
||||
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
outputs = model(input_ids=sequence, attention_mask=attention_mask)
|
||||
logits = outputs.logits
|
||||
|
||||
# Only sample at currently masked positions
|
||||
if current_mask.any():
|
||||
masked_logits = logits[current_mask]
|
||||
|
||||
# Apply temperature scaling
|
||||
if temperature > 0:
|
||||
scaled_logits = masked_logits / temperature
|
||||
else:
|
||||
scaled_logits = masked_logits
|
||||
|
||||
# Suppress mask token in outputs
|
||||
scaled_logits[:, mask_token_id] = -float("inf")
|
||||
|
||||
if temperature > 0:
|
||||
# Add Gumbel noise for sampling
|
||||
gumbel_noise = -torch.log(
|
||||
-torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))
|
||||
)
|
||||
gumbel_logits = scaled_logits + gumbel_noise
|
||||
predicted_tokens = torch.argmax(gumbel_logits, dim=-1)
|
||||
else:
|
||||
predicted_tokens = torch.argmax(scaled_logits, dim=-1)
|
||||
|
||||
# Calculate probabilities for confidence scoring
|
||||
probs = torch.softmax(scaled_logits, dim=-1)
|
||||
predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]
|
||||
|
||||
# Determine how many tokens to unmask this step
|
||||
remaining_masked = current_mask.sum().item()
|
||||
if step == num_diffusion_steps - 1:
|
||||
num_to_unmask = remaining_masked
|
||||
else:
|
||||
unmask_ratio = 1.0 / (num_diffusion_steps - step)
|
||||
num_to_unmask = max(1, int(remaining_masked * unmask_ratio))
|
||||
|
||||
# Select highest confidence predictions to unmask
|
||||
if num_to_unmask >= remaining_masked:
|
||||
sequence[current_mask] = predicted_tokens
|
||||
else:
|
||||
_, top_indices = predicted_token_probs.topk(num_to_unmask)
|
||||
mask_positions = torch.where(current_mask)[1]
|
||||
positions_to_unmask = mask_positions[top_indices]
|
||||
sequence[0, positions_to_unmask] = predicted_tokens[top_indices]
|
||||
|
||||
return sequence
|
||||
41
src/axolotl/integrations/diffusion/plugin.py
Normal file
41
src/axolotl/integrations/diffusion/plugin.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Diffusion LM training plugin for Axolotl."""
|
||||
|
||||
from peft import PeftModel
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .trainer import DiffusionTrainer
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class DiffusionPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for diffusion language model training.
|
||||
|
||||
This plugin enables diffusion-based training using the LLaDA approach, which uses
|
||||
random masking and bidirectional attention to train language models.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.cfg = None
|
||||
|
||||
def get_input_args(self) -> str:
|
||||
"""Returns the pydantic model for LLaDA plugin arguments."""
|
||||
return "axolotl.integrations.diffusion.DiffusionArgs"
|
||||
|
||||
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||
"""Perform actions after model is loaded."""
|
||||
self.cfg = cfg
|
||||
|
||||
def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None:
|
||||
"""Return custom trainer class for diffusion training."""
|
||||
return DiffusionTrainer
|
||||
|
||||
def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
|
||||
"""Configure trainer after creation."""
|
||||
trainer.set_config(cfg)
|
||||
301
src/axolotl/integrations/diffusion/trainer.py
Normal file
301
src/axolotl/integrations/diffusion/trainer.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Custom trainer for diffusion LM training."""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .callbacks import DiffusionGenerationCallback
|
||||
from .utils import create_bidirectional_attention_mask
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class DiffusionTrainer(AxolotlTrainer):
|
||||
"""Custom trainer for diffusion LM training that overrides loss computation."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.cfg = None
|
||||
self._special_token_ids = None
|
||||
|
||||
def set_config(self, config: DictDefault):
|
||||
"""Set config for diffusion training."""
|
||||
self.cfg = config
|
||||
self._cache_special_token_ids()
|
||||
self._resolve_mask_token_id()
|
||||
|
||||
token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0))
|
||||
LOG.info(f"Diffusion: using mask_token_id={token_id}")
|
||||
|
||||
if getattr(config.diffusion, "generate_samples", True):
|
||||
generation_callback = DiffusionGenerationCallback(self)
|
||||
self.add_callback(generation_callback)
|
||||
|
||||
def _resolve_mask_token_id(self) -> None:
|
||||
"""Ensure mask_token_id is valid for the current tokenizer."""
|
||||
from .utils import resolve_mask_token_id
|
||||
|
||||
tokenizer = getattr(self, "processing_class", None)
|
||||
if tokenizer is None:
|
||||
return
|
||||
|
||||
mid = resolve_mask_token_id(
|
||||
tokenizer,
|
||||
self.cfg,
|
||||
allow_add=True,
|
||||
model=getattr(self, "model", None),
|
||||
)
|
||||
try:
|
||||
self.cfg.diffusion.mask_token_id = int(mid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor],
|
||||
return_outputs: bool = False,
|
||||
num_items_in_batch: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
"""Override compute_loss to use diffusion loss."""
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask")
|
||||
labels = inputs.get("labels")
|
||||
|
||||
if input_ids is None:
|
||||
raise ValueError("input_ids is required for diffusion training")
|
||||
|
||||
loss, outputs = self._compute_diffusion_loss(
|
||||
model, input_ids, attention_mask, labels
|
||||
)
|
||||
|
||||
if return_outputs:
|
||||
return loss, outputs
|
||||
return loss
|
||||
|
||||
def _cache_special_token_ids(self):
|
||||
"""Cache special token IDs to avoid repeated tokenizer access."""
|
||||
if self.processing_class is None:
|
||||
self._special_token_ids = set()
|
||||
return
|
||||
|
||||
tokenizer = self.processing_class
|
||||
special_tokens = set()
|
||||
|
||||
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
|
||||
special_tokens.add(tokenizer.bos_token_id)
|
||||
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
|
||||
special_tokens.add(tokenizer.eos_token_id)
|
||||
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
|
||||
special_tokens.add(tokenizer.pad_token_id)
|
||||
|
||||
self._special_token_ids = special_tokens
|
||||
|
||||
def _forward_process(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
labels: torch.Tensor | None = None,
|
||||
eps: float = 1e-3,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward noising process. A timestep is sampled along the process, and tokens are
|
||||
masked with probability determined by the configured noise schedule.
|
||||
|
||||
Args:
|
||||
input_ids: Input token ids [batch_size, seq_len].
|
||||
attention_mask: Attention mask [batch_size, seq_len].
|
||||
labels: Labels for SFT training [batch_size, seq_len].
|
||||
eps: Small epsilon value for minimum masking probability.
|
||||
|
||||
Returns:
|
||||
noisy_batch: Input with some tokens masked.
|
||||
masked_indices: Boolean mask indicating which tokens were masked.
|
||||
p_mask: Masking probabilities for each token [batch_size, seq_len].
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
device = input_ids.device
|
||||
|
||||
# Sample random timesteps for each sample in batch
|
||||
t = torch.rand(batch_size, device=device)
|
||||
p_mask = (1 - eps) * t + eps # [batch_size]
|
||||
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
|
||||
|
||||
# Don't mask padding tokens if attention_mask is provided
|
||||
if attention_mask is not None:
|
||||
valid_mask = attention_mask.bool()
|
||||
p_mask = p_mask * valid_mask.float()
|
||||
|
||||
# Create mask to exclude special tokens
|
||||
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
||||
if self._special_token_ids:
|
||||
for token_id in self._special_token_ids:
|
||||
special_token_mask |= input_ids == token_id
|
||||
|
||||
# Create random mask based on p_mask
|
||||
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
|
||||
masked_indices = masked_indices & ~special_token_mask
|
||||
if attention_mask is not None:
|
||||
masked_indices = masked_indices & attention_mask.bool()
|
||||
|
||||
# For SFT data, only mask answer tokens
|
||||
if labels is not None:
|
||||
answer_mask = labels != -100
|
||||
masked_indices = masked_indices & answer_mask
|
||||
|
||||
# Create masked input
|
||||
mask_token_id = int(self.cfg.diffusion.mask_token_id)
|
||||
mask_value = torch.full_like(input_ids, mask_token_id)
|
||||
noisy_batch = torch.where(masked_indices, mask_value, input_ids)
|
||||
|
||||
return noisy_batch, masked_indices, p_mask
|
||||
|
||||
def _compute_diffusion_loss(
|
||||
self,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
labels: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | Any]:
|
||||
"""
|
||||
Compute diffusion loss.
|
||||
|
||||
Args:
|
||||
model: The model to compute loss for.
|
||||
input_ids: Ground truth token ids [batch_size, seq_len].
|
||||
attention_mask: Attention mask [batch_size, seq_len].
|
||||
labels: Labels for SFT training [batch_size, seq_len].
|
||||
|
||||
Returns:
|
||||
loss: Cross-entropy loss.
|
||||
metrics: Dictionary of metrics.
|
||||
"""
|
||||
# Short-circuit empty sequences
|
||||
if input_ids is None or input_ids.numel() == 0 or input_ids.shape[1] == 0:
|
||||
zero = torch.tensor(
|
||||
0.0,
|
||||
device=(input_ids.device if input_ids is not None else None),
|
||||
requires_grad=True,
|
||||
)
|
||||
return zero, {}
|
||||
|
||||
# If an attention_mask is provided and all positions are padding for every
|
||||
# sample in this batch, skip the step.
|
||||
if attention_mask is not None:
|
||||
if attention_mask.dim() == 2 and (attention_mask.sum(dim=1) == 0).all():
|
||||
zero = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
|
||||
return zero, {}
|
||||
|
||||
# Apply forward process
|
||||
noisy_batch, masked_indices, p_mask = self._forward_process(
|
||||
input_ids, attention_mask, labels, self.cfg.diffusion.eps
|
||||
)
|
||||
|
||||
# Create bidirectional attention mask
|
||||
bidirectional_mask = create_bidirectional_attention_mask(
|
||||
input_ids, attention_mask, sample_packing=self.cfg.sample_packing
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
outputs = model(
|
||||
input_ids=noisy_batch.long(),
|
||||
attention_mask=bidirectional_mask,
|
||||
)
|
||||
logits = outputs.logits
|
||||
|
||||
if masked_indices.sum() > 0:
|
||||
valid_indices = torch.where(masked_indices)
|
||||
batch_indices, seq_indices = valid_indices
|
||||
|
||||
masked_logits = logits[batch_indices, seq_indices]
|
||||
masked_targets = input_ids[batch_indices, seq_indices]
|
||||
masked_p_mask = p_mask[batch_indices, seq_indices]
|
||||
|
||||
# Compute cross-entropy loss without reduction
|
||||
token_loss = F.cross_entropy(
|
||||
masked_logits.float(), masked_targets, reduction="none"
|
||||
)
|
||||
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
masked_p_mask = masked_p_mask.float()
|
||||
weighted_loss = token_loss / masked_p_mask
|
||||
else:
|
||||
weighted_loss = token_loss
|
||||
|
||||
if labels is not None:
|
||||
# For SFT data: normalize by answer token count per sample
|
||||
answer_mask = labels != -100
|
||||
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
|
||||
|
||||
# Get batch indices for masked tokens
|
||||
masked_batch_indices = batch_indices
|
||||
|
||||
# Sum losses per sample and divide by answer length
|
||||
batch_size = input_ids.shape[0]
|
||||
loss_per_sample = torch.zeros(batch_size, device=input_ids.device)
|
||||
for i in range(batch_size):
|
||||
sample_mask = masked_batch_indices == i
|
||||
if sample_mask.sum() > 0:
|
||||
sample_loss = weighted_loss[sample_mask].sum()
|
||||
denom = answer_lengths[i].clamp(min=1.0)
|
||||
loss_per_sample[i] = sample_loss / denom
|
||||
|
||||
loss = loss_per_sample.mean()
|
||||
else:
|
||||
# Non-SFT: when importance weighting is enabled, use unbiased estimator
|
||||
# (sum(loss/p) / total_tokens). Otherwise, average over masked tokens
|
||||
# for stable scaling across varying mask ratios.
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
loss = weighted_loss.sum() / (
|
||||
input_ids.shape[0] * input_ids.shape[1]
|
||||
)
|
||||
else:
|
||||
loss = weighted_loss.mean()
|
||||
|
||||
ce_loss = token_loss.mean()
|
||||
|
||||
# Compute accuracy on masked tokens
|
||||
with torch.no_grad():
|
||||
pred_tokens = masked_logits.argmax(dim=-1)
|
||||
accuracy = (pred_tokens == masked_targets).float().mean()
|
||||
else:
|
||||
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
|
||||
accuracy = torch.tensor(0.0, device=input_ids.device)
|
||||
ce_loss = torch.tensor(0.0, device=input_ids.device)
|
||||
masked_p_mask = torch.tensor(1.0, device=input_ids.device)
|
||||
|
||||
avg_p_mask = (
|
||||
p_mask[masked_indices].mean().item() if masked_indices.any() else 0.0
|
||||
)
|
||||
metrics = {
|
||||
"loss": loss.item(),
|
||||
"accuracy": accuracy.item(),
|
||||
"mask_ratio": masked_indices.float().mean().item(),
|
||||
"num_masked_tokens": (masked_indices.sum().item(), "sum"),
|
||||
"avg_p_mask": avg_p_mask,
|
||||
"ce_loss": ce_loss.item(),
|
||||
}
|
||||
|
||||
# If doing SFT training, log answer-specific metrics
|
||||
if self.cfg.datasets is not None:
|
||||
with torch.no_grad():
|
||||
answer_mask = labels != -100
|
||||
answer_lengths = answer_mask.sum(dim=1).float() # type: ignore
|
||||
total_answer_tokens = answer_mask.sum().item() # type: ignore
|
||||
total_tokens = labels.numel() # type: ignore
|
||||
metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1)
|
||||
metrics["avg_answer_length"] = answer_lengths.mean().item()
|
||||
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
|
||||
|
||||
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
|
||||
self.store_metrics(metrics, train_eval=train_eval)
|
||||
|
||||
return loss, outputs
|
||||
159
src/axolotl/integrations/diffusion/utils.py
Normal file
159
src/axolotl/integrations/diffusion/utils.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Shared utilities for diffusion integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def resolve_mask_token_id(
|
||||
tokenizer: Any,
|
||||
cfg: DictDefault,
|
||||
*,
|
||||
allow_add: bool,
|
||||
model: Any | None = None,
|
||||
default_token: str = "<|diffusion_mask|>",
|
||||
) -> int:
|
||||
"""Resolve mask token id. Training may add a new special token; inference won't."""
|
||||
# Determine vocab size if available
|
||||
vocab_size = None
|
||||
if tokenizer is not None:
|
||||
if hasattr(tokenizer, "vocab_size") and tokenizer.vocab_size is not None:
|
||||
try:
|
||||
vocab_size = int(tokenizer.vocab_size) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
vocab_size = None
|
||||
elif hasattr(tokenizer, "__len__"):
|
||||
try:
|
||||
vocab_size = int(len(tokenizer))
|
||||
except Exception:
|
||||
vocab_size = None
|
||||
|
||||
# Use explicit id from config if provided
|
||||
diffusion_cfg = getattr(cfg, "diffusion", None)
|
||||
# Fallback to top-level attr names only if nested missing (shouldn't happen)
|
||||
cfg_id = (
|
||||
getattr(diffusion_cfg, "mask_token_id", None)
|
||||
if diffusion_cfg is not None
|
||||
else getattr(cfg, "diffusion_mask_token_id", None)
|
||||
)
|
||||
if isinstance(cfg_id, int) and cfg_id >= 0:
|
||||
if vocab_size is None or cfg_id < vocab_size:
|
||||
return int(cfg_id)
|
||||
|
||||
def _existing_special_token_id(token_str: str | None) -> int | None:
|
||||
"""Attempt to resolve an existing special token string to a real ID."""
|
||||
if not token_str or not hasattr(tokenizer, "convert_tokens_to_ids"):
|
||||
return None
|
||||
try:
|
||||
token_id = tokenizer.convert_tokens_to_ids(token_str)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not isinstance(token_id, int) or token_id < 0:
|
||||
return None
|
||||
|
||||
# Ensure it's registered as special and not UNK, and within vocab
|
||||
unk_id = getattr(tokenizer, "unk_token_id", None)
|
||||
specials = set(getattr(tokenizer, "all_special_tokens", []) or [])
|
||||
addl = set(getattr(tokenizer, "additional_special_tokens", []) or [])
|
||||
is_special = token_str in specials or token_str in addl
|
||||
in_vocab = vocab_size is None or token_id < vocab_size
|
||||
if (
|
||||
(unk_id is not None and token_id == unk_id)
|
||||
or not is_special
|
||||
or not in_vocab
|
||||
):
|
||||
return None
|
||||
return token_id
|
||||
|
||||
# Try mask token string if provided
|
||||
token_str = (
|
||||
getattr(diffusion_cfg, "mask_token_str", None)
|
||||
if diffusion_cfg is not None
|
||||
else getattr(cfg, "diffusion_mask_token_str", None)
|
||||
)
|
||||
for candidate in (token_str, default_token):
|
||||
token_id = _existing_special_token_id(candidate)
|
||||
if isinstance(token_id, int):
|
||||
try:
|
||||
if diffusion_cfg is None:
|
||||
cfg.diffusion_mask_token_id = int(token_id) # legacy fallback
|
||||
else:
|
||||
diffusion_cfg.mask_token_id = int(token_id)
|
||||
except Exception:
|
||||
pass
|
||||
return int(token_id)
|
||||
|
||||
# Optionally add and return a dedicated special token during training
|
||||
if allow_add and hasattr(tokenizer, "add_special_tokens"):
|
||||
token_to_add = token_str or default_token
|
||||
try:
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [token_to_add]})
|
||||
|
||||
# Resize embeddings if possible
|
||||
if (
|
||||
model is not None
|
||||
and hasattr(tokenizer, "__len__")
|
||||
and hasattr(model, "resize_token_embeddings")
|
||||
):
|
||||
try:
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
except Exception:
|
||||
pass
|
||||
new_id = tokenizer.convert_tokens_to_ids(token_to_add)
|
||||
if isinstance(new_id, int) and new_id >= 0:
|
||||
try:
|
||||
if diffusion_cfg is None:
|
||||
cfg.diffusion_mask_token_id = int(new_id) # legacy fallback
|
||||
else:
|
||||
diffusion_cfg.mask_token_id = int(new_id)
|
||||
except Exception:
|
||||
pass
|
||||
return int(new_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback to unk or 0 (do not update cfg)
|
||||
fallback = getattr(tokenizer, "unk_token_id", 0) or 0
|
||||
return int(fallback)
|
||||
|
||||
|
||||
def create_bidirectional_attention_mask(
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sample_packing: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Create bidirectional attention mask to override default causal masking.
|
||||
Handles sample-packed sequences where different samples are identified
|
||||
by different attention mask values.
|
||||
|
||||
Args:
|
||||
input_ids: Input token ids [batch_size, seq_len]
|
||||
attention_mask: Attention mask [batch_size, seq_len]
|
||||
sample_packing: Whether sample packing is enabled
|
||||
|
||||
Returns:
|
||||
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
device = input_ids.device
|
||||
|
||||
if attention_mask is None or not sample_packing:
|
||||
return torch.ones(
|
||||
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
|
||||
)
|
||||
|
||||
# Handle sample packing: tokens can only attend within their sample
|
||||
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
|
||||
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
|
||||
|
||||
# Tokens can attend to each other if they have the same non-zero sample ID
|
||||
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
|
||||
|
||||
# Add head dimension: [batch_size, 1, seq_len, seq_len]
|
||||
return bidirectional_mask.unsqueeze(1)
|
||||
@@ -7,7 +7,7 @@ from transformers.trainer_callback import TrainerCallback
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from ..base import BasePlugin
|
||||
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
|
||||
from .args import GrokfastArgs as GrokfastArgs
|
||||
from .optimizer import gradfilter_ema
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -24,12 +24,10 @@ class GrokfastCallbackHandler(TrainerCallback):
|
||||
self.alpha = alpha
|
||||
self.lamb = lamb
|
||||
|
||||
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
|
||||
def on_train_begin(self, *args_, **kwargs):
|
||||
self.grads = None
|
||||
|
||||
def on_pre_optimizer_step(
|
||||
self, args_, state, control, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
def on_pre_optimizer_step(self, args_, state, control, **kwargs):
|
||||
model = kwargs.pop("model")
|
||||
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
|
||||
return control
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user