Compare commits
3 Commits
feat/lmeva
...
streaming-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d3bea3a2eb | ||
|
|
2e2302aae3 | ||
|
|
3a35076513 |
@@ -12,6 +12,6 @@ reviews:
|
|||||||
auto_review:
|
auto_review:
|
||||||
enabled: true
|
enabled: true
|
||||||
drafts: false
|
drafts: false
|
||||||
auto_incremental_review: false
|
auto_incremental_review: true
|
||||||
chat:
|
chat:
|
||||||
auto_reply: true
|
auto_reply: true
|
||||||
|
|||||||
16
.github/workflows/main.yml
vendored
16
.github/workflows/main.yml
vendored
@@ -36,11 +36,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.8.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -115,11 +110,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.8.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -179,12 +169,6 @@ jobs:
|
|||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
axolotl_extras: vllm
|
axolotl_extras: vllm
|
||||||
is_latest: true
|
is_latest: true
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.8.0
|
|
||||||
axolotl_extras:
|
|
||||||
is_latest:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
12
.github/workflows/multi-gpu-e2e.yml
vendored
12
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -36,15 +36,15 @@ jobs:
|
|||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.7.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 128
|
- cuda: 126
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.8.0
|
pytorch: 2.7.1
|
||||||
axolotl_extras:
|
axolotl_extras: vllm
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
|
|||||||
18
.github/workflows/tests.yml
vendored
18
.github/workflows/tests.yml
vendored
@@ -55,7 +55,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
|
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -130,7 +130,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
|
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -240,7 +240,7 @@ jobs:
|
|||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
dockerfile: "Dockerfile-uv.jinja"
|
dockerfile: "Dockerfile-uv.jinja"
|
||||||
@@ -298,12 +298,6 @@ jobs:
|
|||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.8.0
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras:
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -340,10 +334,10 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
- cuda: 124
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
provider: baseten
|
|
||||||
project_name:
|
|
||||||
|
|
||||||
secrets:
|
|
||||||
- HF_TOKEN
|
|
||||||
- WANDB_API_KEY
|
|
||||||
|
|
||||||
gpu: h100
|
|
||||||
gpu_count: 8
|
|
||||||
node_count: 1
|
|
||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.47.0
|
bitsandbytes==0.47.0
|
||||||
triton>=3.0.0
|
# triton 3.4.0 is not compatible with CCE
|
||||||
|
triton>=3.0.0,<3.4.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
autoawq==0.2.7.post3
|
autoawq==0.2.7.post3
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
|
||||||
)
|
)
|
||||||
|
|||||||
4
setup.py
4
setup.py
@@ -64,9 +64,7 @@ def parse_requirements(extras_require_map):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
if (major, minor) >= (2, 8):
|
if (major, minor) >= (2, 7):
|
||||||
pass
|
|
||||||
elif (major, minor) >= (2, 7):
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.append("xformers==0.0.30")
|
_install_requires.append("xformers==0.0.30")
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ from typing import Literal
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.cli.cloud.base import Cloud
|
|
||||||
from axolotl.cli.cloud.baseten import BasetenCloud
|
|
||||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -40,15 +38,8 @@ def do_cli_train(
|
|||||||
cwd=None,
|
cwd=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
provider = cloud_cfg.provider or "modal"
|
cloud = ModalCloud(cloud_cfg)
|
||||||
cloud: Cloud | None
|
|
||||||
if provider == "modal":
|
|
||||||
cloud = ModalCloud(cloud_cfg)
|
|
||||||
elif provider == "baseten":
|
|
||||||
cloud = BasetenCloud(cloud_cfg.to_dict())
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported cloud provider: {provider}")
|
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
config_yaml = file.read()
|
config_yaml = file.read()
|
||||||
local_dirs = {}
|
local_dirs = {}
|
||||||
@@ -67,16 +58,8 @@ def do_cli_lm_eval(
|
|||||||
cloud_config: Path | str,
|
cloud_config: Path | str,
|
||||||
config: Path | str,
|
config: Path | str,
|
||||||
) -> None:
|
) -> None:
|
||||||
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
provider = cloud_cfg.provider or "modal"
|
cloud = ModalCloud(cloud_cfg)
|
||||||
cloud: Cloud | None
|
|
||||||
if provider == "modal":
|
|
||||||
cloud = ModalCloud(cloud_cfg)
|
|
||||||
elif provider == "baseten":
|
|
||||||
cloud = BasetenCloud(cloud_cfg.to_dict())
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported cloud provider: {provider}")
|
|
||||||
|
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
config_yaml = file.read()
|
config_yaml = file.read()
|
||||||
cloud.lm_eval(config_yaml)
|
cloud.lm_eval(config_yaml)
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
"""Baseten Cloud CLI"""
|
|
||||||
|
|
||||||
import shutil
|
|
||||||
import subprocess # nosec B404
|
|
||||||
import tempfile
|
|
||||||
from os.path import dirname
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from axolotl.cli.cloud.base import Cloud
|
|
||||||
|
|
||||||
|
|
||||||
class BasetenCloud(Cloud):
|
|
||||||
"""Baseten Cloud Axolotl CLI"""
|
|
||||||
|
|
||||||
def __init__(self, config: dict):
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Separate preprocess function for Baseten is not "
|
|
||||||
"implemented and will happen during hte train step."
|
|
||||||
)
|
|
||||||
|
|
||||||
def train(
|
|
||||||
self,
|
|
||||||
config_yaml: str,
|
|
||||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
|
||||||
launcher_args: list[str] | None = None,
|
|
||||||
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
config = self.config.copy()
|
|
||||||
config["launcher"] = launcher
|
|
||||||
config["launcher_args"] = launcher_args
|
|
||||||
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
|
|
||||||
yaml.dump(config, cloud_fout)
|
|
||||||
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
|
|
||||||
config_fout.write(config_yaml)
|
|
||||||
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
|
|
||||||
shutil.copyfile(
|
|
||||||
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
|
|
||||||
)
|
|
||||||
subprocess.run( # nosec B603 B607
|
|
||||||
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def lm_eval(
|
|
||||||
self,
|
|
||||||
config_yaml: str,
|
|
||||||
):
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
config = self.config.copy()
|
|
||||||
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
|
|
||||||
yaml.dump(config, cloud_fout)
|
|
||||||
with open(tmp_dir + "/eval.yaml", "w", encoding="utf-8") as config_fout:
|
|
||||||
config_fout.write(config_yaml)
|
|
||||||
shutil.copyfile(
|
|
||||||
dirname(__file__) + "/template/eval.sh", tmp_dir + "/eval.sh"
|
|
||||||
)
|
|
||||||
shutil.copyfile(
|
|
||||||
dirname(__file__) + "/template/eval_sft.py", tmp_dir + "/eval_sft.py"
|
|
||||||
)
|
|
||||||
subprocess.run( # nosec B603 B607
|
|
||||||
["truss", "train", "push", "eval_sft.py"], cwd=tmp_dir, check=False
|
|
||||||
)
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
set -eux
|
|
||||||
|
|
||||||
export NCCL_SOCKET_IFNAME="^docker0,lo"
|
|
||||||
export NCCL_IB_DISABLE=0
|
|
||||||
export NCCL_TIMEOUT=1800000
|
|
||||||
|
|
||||||
axolotl lm-eval eval.yaml
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
"""
|
|
||||||
Baseten Training Script for Axolotl
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: skip-file
|
|
||||||
import yaml
|
|
||||||
from truss.base import truss_config
|
|
||||||
|
|
||||||
# Import necessary classes from the Baseten Training SDK
|
|
||||||
from truss_train import definitions
|
|
||||||
|
|
||||||
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
|
|
||||||
gpu = cloud_config.get("gpu", "h100")
|
|
||||||
gpu_count = (
|
|
||||||
1 # int(cloud_config.get("gpu_count", 1)) # only single GPU supported at the moment
|
|
||||||
)
|
|
||||||
node_count = (
|
|
||||||
1 # int(cloud_config.get("node_count", 1)) # only single node support for lmeval
|
|
||||||
)
|
|
||||||
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
|
|
||||||
secrets = cloud_config.get("secrets", [])
|
|
||||||
# launcher = cloud_config.get("launcher", "accelerate")
|
|
||||||
# launcher_args = cloud_config.get("launcher_args", [])
|
|
||||||
script_name = "eval.sh"
|
|
||||||
|
|
||||||
# launcher_args_str = ""
|
|
||||||
# if launcher_args:
|
|
||||||
# launcher_args_str = "-- " + " ".join(launcher_args)
|
|
||||||
|
|
||||||
# 1. Define a base image for your training job
|
|
||||||
# must use torch 2.7.0 for vllm
|
|
||||||
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
|
|
||||||
|
|
||||||
# 2. Define the Runtime Environment for the Training Job
|
|
||||||
# This includes start commands and environment variables.a
|
|
||||||
# Secrets from the baseten workspace like API keys are referenced using
|
|
||||||
# `SecretReference`.
|
|
||||||
|
|
||||||
env_vars = {
|
|
||||||
# "AXOLOTL_LAUNCHER": launcher,
|
|
||||||
# "AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
|
|
||||||
}
|
|
||||||
for secret_name in secrets:
|
|
||||||
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
|
|
||||||
|
|
||||||
training_runtime = definitions.Runtime(
|
|
||||||
start_commands=[ # Example: list of commands to run your training script
|
|
||||||
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
|
|
||||||
],
|
|
||||||
environment_variables=env_vars,
|
|
||||||
cache_config=definitions.CacheConfig(
|
|
||||||
enabled=True,
|
|
||||||
),
|
|
||||||
checkpointing_config=definitions.CheckpointingConfig(
|
|
||||||
enabled=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Define the Compute Resources for the Training Job
|
|
||||||
training_compute = definitions.Compute(
|
|
||||||
node_count=node_count,
|
|
||||||
accelerator=truss_config.AcceleratorSpec(
|
|
||||||
accelerator=truss_config.Accelerator.H100,
|
|
||||||
count=gpu_count,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Define the Training Job
|
|
||||||
# This brings together the image, compute, and runtime configurations.
|
|
||||||
my_training_job = definitions.TrainingJob(
|
|
||||||
image=definitions.Image(base_image=BASE_IMAGE),
|
|
||||||
compute=training_compute,
|
|
||||||
runtime=training_runtime,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# This config will be pushed using the Truss CLI.
|
|
||||||
# The association of the job to the project happens at the time of push.
|
|
||||||
first_project_with_job = definitions.TrainingProject(
|
|
||||||
name=project_name, job=my_training_job
|
|
||||||
)
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
set -eux
|
|
||||||
|
|
||||||
export NCCL_SOCKET_IFNAME="^docker0,lo"
|
|
||||||
export NCCL_IB_DISABLE=0
|
|
||||||
export NCCL_TIMEOUT=1800000
|
|
||||||
|
|
||||||
axolotl preprocess train.yaml
|
|
||||||
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
"""
|
|
||||||
Baseten Training Script for Axolotl
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: skip-file
|
|
||||||
import yaml
|
|
||||||
from truss.base import truss_config
|
|
||||||
|
|
||||||
# Import necessary classes from the Baseten Training SDK
|
|
||||||
from truss_train import definitions
|
|
||||||
|
|
||||||
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
|
|
||||||
gpu = cloud_config.get("gpu", "h100")
|
|
||||||
gpu_count = int(cloud_config.get("gpu_count", 1))
|
|
||||||
node_count = int(cloud_config.get("node_count", 1))
|
|
||||||
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
|
|
||||||
secrets = cloud_config.get("secrets", [])
|
|
||||||
launcher = cloud_config.get("launcher", "accelerate")
|
|
||||||
launcher_args = cloud_config.get("launcher_args", [])
|
|
||||||
script_name = "run.sh"
|
|
||||||
|
|
||||||
launcher_args_str = ""
|
|
||||||
if launcher_args:
|
|
||||||
launcher_args_str = "-- " + " ".join(launcher_args)
|
|
||||||
|
|
||||||
# 1. Define a base image for your training job
|
|
||||||
# must use torch 2.7.0 for vllm
|
|
||||||
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
|
|
||||||
|
|
||||||
# 2. Define the Runtime Environment for the Training Job
|
|
||||||
# This includes start commands and environment variables.a
|
|
||||||
# Secrets from the baseten workspace like API keys are referenced using
|
|
||||||
# `SecretReference`.
|
|
||||||
|
|
||||||
env_vars = {
|
|
||||||
"AXOLOTL_LAUNCHER": launcher,
|
|
||||||
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
|
|
||||||
}
|
|
||||||
for secret_name in secrets:
|
|
||||||
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
|
|
||||||
|
|
||||||
training_runtime = definitions.Runtime(
|
|
||||||
start_commands=[ # Example: list of commands to run your training script
|
|
||||||
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
|
|
||||||
],
|
|
||||||
environment_variables=env_vars,
|
|
||||||
cache_config=definitions.CacheConfig(
|
|
||||||
enabled=True,
|
|
||||||
),
|
|
||||||
checkpointing_config=definitions.CheckpointingConfig(
|
|
||||||
enabled=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Define the Compute Resources for the Training Job
|
|
||||||
training_compute = definitions.Compute(
|
|
||||||
node_count=node_count,
|
|
||||||
accelerator=truss_config.AcceleratorSpec(
|
|
||||||
accelerator=truss_config.Accelerator.H100,
|
|
||||||
count=gpu_count,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Define the Training Job
|
|
||||||
# This brings together the image, compute, and runtime configurations.
|
|
||||||
my_training_job = definitions.TrainingJob(
|
|
||||||
image=definitions.Image(base_image=BASE_IMAGE),
|
|
||||||
compute=training_compute,
|
|
||||||
runtime=training_runtime,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# This config will be pushed using the Truss CLI.
|
|
||||||
# The association of the job to the project happens at the time of push.
|
|
||||||
first_project_with_job = definitions.TrainingProject(
|
|
||||||
name=project_name, job=my_training_job
|
|
||||||
)
|
|
||||||
@@ -43,10 +43,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
safe_serialization=safe_serialization,
|
safe_serialization=safe_serialization,
|
||||||
progressbar=True,
|
progressbar=True,
|
||||||
)
|
)
|
||||||
tokenizer.save_pretrained(
|
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
str(Path(cfg.output_dir) / "merged"),
|
|
||||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
|
||||||
)
|
|
||||||
|
|
||||||
if processor:
|
if processor:
|
||||||
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
|
|||||||
@@ -84,6 +84,5 @@ def do_quantize(
|
|||||||
str(Path(output_dir) / "quantized"),
|
str(Path(output_dir) / "quantized"),
|
||||||
safe_serialization=False,
|
safe_serialization=False,
|
||||||
progressbar=True,
|
progressbar=True,
|
||||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
|
||||||
)
|
)
|
||||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
|
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
|
||||||
|
|||||||
@@ -24,7 +24,9 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import TrainerCallback
|
from transformers import (
|
||||||
|
TrainerCallback,
|
||||||
|
)
|
||||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
@@ -36,7 +38,6 @@ from axolotl.utils.callbacks import (
|
|||||||
SaveModelOnFirstStepCallback,
|
SaveModelOnFirstStepCallback,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||||
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
|
|
||||||
from axolotl.utils.distributed import build_parallelism_config
|
from axolotl.utils.distributed import build_parallelism_config
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
@@ -145,12 +146,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
profiler_steps_start=self.cfg.profiler_steps_start,
|
profiler_steps_start=self.cfg.profiler_steps_start,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if self.cfg.include_tkps:
|
|
||||||
callbacks.append(
|
|
||||||
TokensPerSecondCallback(
|
|
||||||
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -517,7 +512,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
self.cfg.eval_batch_size
|
self.cfg.eval_batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
|
|
||||||
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
||||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||||
|
|
||||||
|
|||||||
@@ -404,9 +404,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
)
|
)
|
||||||
trainer = self.hook_post_create_trainer(trainer)
|
trainer = self.hook_post_create_trainer(trainer)
|
||||||
# if the trainer has the `axolotl_cfg` property, set it
|
|
||||||
if hasattr(trainer, "axolotl_cfg"):
|
|
||||||
trainer.axolotl_cfg = self.cfg
|
|
||||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||||
trainer.add_callback(callback)
|
trainer.add_callback(callback)
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ from axolotl.core.trainers.utils import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils import get_not_null
|
from axolotl.utils import get_not_null
|
||||||
from axolotl.utils.bench import get_gpu_memory_usage
|
from axolotl.utils.bench import get_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
@@ -64,15 +63,6 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
tag_names = ["axolotl"]
|
||||||
_axolotl_cfg: DictDefault | None = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def axolotl_cfg(self):
|
|
||||||
return self._axolotl_cfg
|
|
||||||
|
|
||||||
@axolotl_cfg.setter
|
|
||||||
def axolotl_cfg(self, cfg):
|
|
||||||
self._axolotl_cfg = cfg
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -88,6 +78,7 @@ class AxolotlTrainer(
|
|||||||
self._signature_columns = None # workaround for pylint
|
self._signature_columns = None # workaround for pylint
|
||||||
|
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
|
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
@@ -336,17 +327,6 @@ class AxolotlTrainer(
|
|||||||
# outputs = model(**inputs)
|
# outputs = model(**inputs)
|
||||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||||
# return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
|
|
||||||
# track number of tokens for tokens per second calculation
|
|
||||||
if self.args.include_tkps:
|
|
||||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
|
||||||
if hasattr(self.state, "num_tokens"):
|
|
||||||
self.state.num_tokens = (
|
|
||||||
self.state.num_tokens + (inputs[inputs_key] != -100).sum()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.state.num_tokens = (inputs[inputs_key] != -100).sum()
|
|
||||||
|
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
return self.orpo_compute_loss(
|
return self.orpo_compute_loss(
|
||||||
model,
|
model,
|
||||||
@@ -546,6 +526,9 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
super().create_accelerator_and_postprocess()
|
super().create_accelerator_and_postprocess()
|
||||||
|
|
||||||
|
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
|
||||||
|
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
if (
|
if (
|
||||||
"limit_all_gathers" in self.args.fsdp_config
|
"limit_all_gathers" in self.args.fsdp_config
|
||||||
@@ -593,19 +576,12 @@ class AxolotlTrainer(
|
|||||||
# Add memory usage
|
# Add memory usage
|
||||||
try:
|
try:
|
||||||
active, allocated, reserved = get_gpu_memory_usage()
|
active, allocated, reserved = get_gpu_memory_usage()
|
||||||
logs["memory/max_active (GiB)"] = round(active, 2)
|
logs["memory/max_mem_active(gib)"] = round(active, 2)
|
||||||
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
|
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
|
||||||
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
|
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
|
||||||
except (ValueError, TypeError, FileNotFoundError):
|
except (ValueError, TypeError, FileNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if self.args.include_tkps and train_eval == "train":
|
|
||||||
# each rank will log its own tokens per second
|
|
||||||
# for logging_steps > 1 we obtain a moving average of this metric
|
|
||||||
logs["tokens_per_second_per_gpu"] = round(
|
|
||||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
|
||||||
)
|
|
||||||
|
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
|
|
||||||
return super().log(logs, start_time)
|
return super().log(logs, start_time)
|
||||||
@@ -681,11 +657,6 @@ class AxolotlTrainer(
|
|||||||
LOG.info(
|
LOG.info(
|
||||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||||
)
|
)
|
||||||
save_jinja_files = True
|
self.data_collator.tokenizer.save_pretrained(output_dir)
|
||||||
if self.axolotl_cfg:
|
|
||||||
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
|
|
||||||
self.data_collator.tokenizer.save_pretrained(
|
|
||||||
output_dir, save_jinja_files=save_jinja_files
|
|
||||||
)
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||||
|
|||||||
@@ -49,12 +49,6 @@ class AxolotlTrainingMixins:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
)
|
)
|
||||||
include_tkps: bool = field(
|
|
||||||
default=True,
|
|
||||||
metadata={
|
|
||||||
"help": "Whether to include tokens per second in the training metrics."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
eval_sample_packing: Optional[bool] = field(
|
eval_sample_packing: Optional[bool] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Use sample packing for efficient evals."},
|
metadata={"help": "Use sample packing for efficient evals."},
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -149,11 +149,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|||||||
|
|
||||||
return MistralAttention
|
return MistralAttention
|
||||||
|
|
||||||
if model_type == "gemma3_text":
|
|
||||||
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
|
|
||||||
|
|
||||||
return Gemma3Attention
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Dynamically import the module and attention class
|
# Dynamically import the module and attention class
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
|
|||||||
@@ -416,9 +416,7 @@ def save_initial_configs(
|
|||||||
|
|
||||||
# Pre-save the tokenizer and model configs
|
# Pre-save the tokenizer and model configs
|
||||||
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
|
||||||
tokenizer.save_pretrained(
|
tokenizer.save_pretrained(str(output_dir))
|
||||||
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
|
|
||||||
)
|
|
||||||
if hasattr(model, "config"):
|
if hasattr(model, "config"):
|
||||||
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||||
model.config.save_pretrained(str(output_dir))
|
model.config.save_pretrained(str(output_dir))
|
||||||
@@ -594,9 +592,6 @@ def train(
|
|||||||
|
|
||||||
# Save the trained model and cleanup
|
# Save the trained model and cleanup
|
||||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||||
tokenizer.save_pretrained(
|
|
||||||
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
|
|
||||||
)
|
|
||||||
create_model_card(cfg, trainer)
|
create_model_card(cfg, trainer)
|
||||||
if not cfg.use_ray:
|
if not cfg.use_ray:
|
||||||
cleanup_distributed()
|
cleanup_distributed()
|
||||||
|
|||||||
@@ -60,14 +60,13 @@ def gpu_memory_usage_all(device=0):
|
|||||||
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
|
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
|
||||||
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
|
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
|
||||||
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
|
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
|
||||||
torch.cuda.reset_peak_memory_stats(device)
|
|
||||||
return active, allocated, reserved
|
return active, allocated, reserved
|
||||||
|
|
||||||
|
|
||||||
def mps_memory_usage_all():
|
def mps_memory_usage_all():
|
||||||
active = torch.mps.current_allocated_memory() / 1024.0**3
|
usage = torch.mps.current_allocated_memory() / 1024.0**3
|
||||||
allocated = torch.mps.driver_allocated_memory() / 1024.0**3
|
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
|
||||||
return active, allocated, 0
|
return usage, reserved - usage, 0
|
||||||
|
|
||||||
|
|
||||||
def npu_memory_usage_all(device=0):
|
def npu_memory_usage_all(device=0):
|
||||||
|
|||||||
@@ -1,62 +0,0 @@
|
|||||||
"""A callback for calculating tokens per second during training."""
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import (
|
|
||||||
TrainerCallback,
|
|
||||||
TrainerControl,
|
|
||||||
TrainerState,
|
|
||||||
TrainingArguments,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TokensPerSecondCallback(TrainerCallback):
|
|
||||||
"""
|
|
||||||
A callback to measure and log tokens per second during training.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, tensor_parallel_size, context_parallel_size):
|
|
||||||
super().__init__()
|
|
||||||
self.step_time = 0.0
|
|
||||||
self.start_time = 0.0
|
|
||||||
self.non_data_parallel_size = 1
|
|
||||||
if tensor_parallel_size is not None:
|
|
||||||
self.non_data_parallel_size *= tensor_parallel_size
|
|
||||||
if context_parallel_size is not None:
|
|
||||||
self.non_data_parallel_size *= context_parallel_size
|
|
||||||
|
|
||||||
def on_step_begin(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments,
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
**kwargs,
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
self.start_time = time.perf_counter()
|
|
||||||
state.last_tokens_per_second = torch.zeros(1)
|
|
||||||
|
|
||||||
def on_step_end(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments,
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
**kwargs,
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
step_time = time.perf_counter() - self.start_time
|
|
||||||
num_tokens_per_device = state.num_tokens.clone()
|
|
||||||
# non data parallel groups have duplicated tokens, so we avoid double-counting
|
|
||||||
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
|
|
||||||
state.last_tokens_per_second = num_tokens_per_device / step_time
|
|
||||||
|
|
||||||
def on_log(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments,
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
logs=None,
|
|
||||||
**kwargs,
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
# after logging, clear the running metrics
|
|
||||||
state.last_tokens_per_second.zero_()
|
|
||||||
state.num_tokens = 0
|
|
||||||
@@ -77,7 +77,7 @@ def resolve_dtype(cfg):
|
|||||||
if cfg.device == "mps":
|
if cfg.device == "mps":
|
||||||
cfg.load_in_8bit = False
|
cfg.load_in_8bit = False
|
||||||
cfg.tf32 = False
|
cfg.tf32 = False
|
||||||
if cfg.bf16 and cfg.fp16 is not False:
|
if cfg.bf16:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
cfg.bf16 = False
|
cfg.bf16 = False
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from axolotl.utils.data.shared import (
|
|||||||
save_preprocessed_dataset,
|
save_preprocessed_dataset,
|
||||||
try_load_from_hub,
|
try_load_from_hub,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.data.streaming import wrap_streaming_sft_dataset
|
||||||
from axolotl.utils.data.utils import (
|
from axolotl.utils.data.utils import (
|
||||||
deduplicate_and_log_datasets,
|
deduplicate_and_log_datasets,
|
||||||
handle_long_seq_in_dataset,
|
handle_long_seq_in_dataset,
|
||||||
@@ -73,7 +74,7 @@ def _prepare_standard_dataset(
|
|||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
processor: ProcessorMixin | None,
|
processor: ProcessorMixin | None,
|
||||||
preprocess_iterable: bool,
|
preprocess_iterable: bool,
|
||||||
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
|
) -> tuple[Dataset | IterableDataset, Dataset | None, int, list[Prompter | None]]:
|
||||||
"""Prepare standard (non-pretraining) datasets."""
|
"""Prepare standard (non-pretraining) datasets."""
|
||||||
|
|
||||||
def _load_datasets():
|
def _load_datasets():
|
||||||
@@ -118,7 +119,14 @@ def _prepare_standard_dataset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Calculate total number of training steps
|
# Calculate total number of training steps
|
||||||
if cfg.max_steps:
|
# For streaming datasets, we must use max_steps
|
||||||
|
if isinstance(train_dataset, IterableDataset):
|
||||||
|
if not cfg.max_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"When using streaming datasets, you must set max_steps in your config"
|
||||||
|
)
|
||||||
|
total_num_steps = cfg.max_steps
|
||||||
|
elif cfg.max_steps:
|
||||||
total_num_steps = min(
|
total_num_steps = min(
|
||||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||||
)
|
)
|
||||||
@@ -342,14 +350,18 @@ def _load_raw_datasets(
|
|||||||
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||||
else:
|
else:
|
||||||
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||||
if cfg.sample_packing:
|
|
||||||
|
# Skip packing processing for streaming datasets - they handle it differently
|
||||||
|
if cfg.sample_packing and not isinstance(dataset, IterableDataset):
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
# Save the prepared dataset
|
# Skip saving for streaming datasets as they can't be cached
|
||||||
dataset_hash = generate_dataset_hash_from_config(
|
if not isinstance(dataset, IterableDataset):
|
||||||
cfg, datasets_configs, tokenizer.name_or_path
|
# Save the prepared dataset
|
||||||
)
|
dataset_hash = generate_dataset_hash_from_config(
|
||||||
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
|
cfg, datasets_configs, tokenizer.name_or_path
|
||||||
|
)
|
||||||
|
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
|
||||||
|
|
||||||
return dataset, prompters
|
return dataset, prompters
|
||||||
|
|
||||||
@@ -365,8 +377,10 @@ def _load_and_process_single_dataset(
|
|||||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||||
"""Load and process a single dataset based on the passed config."""
|
"""Load and process a single dataset based on the passed config."""
|
||||||
# Load the dataset
|
# Load the dataset
|
||||||
|
# Use streaming if enabled in config or if using iterable preprocessing
|
||||||
|
use_streaming = cfg.streaming or preprocess_iterable
|
||||||
dataset = load_dataset_with_config(
|
dataset = load_dataset_with_config(
|
||||||
dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
|
dataset_config, cfg.hf_use_auth_token, streaming=use_streaming
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse dataset type
|
# Parse dataset type
|
||||||
@@ -391,16 +405,63 @@ def _load_and_process_single_dataset(
|
|||||||
num_shards=dataset_config.shards, index=shards_idx
|
num_shards=dataset_config.shards, index=shards_idx
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply dataset wrapper
|
# For streaming datasets, we need to handle tokenization differently
|
||||||
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
if isinstance(dataset, IterableDataset):
|
||||||
dataset_config=dataset_config,
|
# Use pretraining's approach for multipack streaming
|
||||||
tokenizer=tokenizer,
|
if cfg.sample_packing:
|
||||||
cfg=cfg,
|
# Create the dataset wrapper function once
|
||||||
dataset_base_type=d_base_type,
|
def ds_wrapper_fn(dataset=None):
|
||||||
dataset=dataset,
|
wrapped_dataset, prompter = get_dataset_wrapper(
|
||||||
dataset_prompt_style=d_prompt_style,
|
dataset_config=dataset_config,
|
||||||
processor=processor,
|
tokenizer=tokenizer,
|
||||||
)
|
cfg=cfg,
|
||||||
|
dataset_base_type=d_base_type,
|
||||||
|
dataset=dataset,
|
||||||
|
dataset_prompt_style=d_prompt_style,
|
||||||
|
processor=processor,
|
||||||
|
)
|
||||||
|
return wrapped_dataset, prompter
|
||||||
|
|
||||||
|
# Use pretraining wrapper for efficient streaming SFT with packing
|
||||||
|
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||||
|
|
||||||
|
dataset_wrapper = wrap_pretraining_dataset(
|
||||||
|
dataset,
|
||||||
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
ds_wrapper_fn,
|
||||||
|
max_tokens=cfg.sequence_len,
|
||||||
|
batch_size=cfg.micro_batch_size,
|
||||||
|
seed=cfg.seed,
|
||||||
|
buffer_size=cfg.pretrain_multipack_buffer_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use regular streaming wrapper
|
||||||
|
dataset_wrapper = wrap_streaming_sft_dataset(
|
||||||
|
dataset,
|
||||||
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
dataset_config,
|
||||||
|
d_base_type,
|
||||||
|
d_prompt_style,
|
||||||
|
processor,
|
||||||
|
max_tokens=cfg.sequence_len,
|
||||||
|
buffer_size=10_000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# For streaming, we don't have a specific prompter
|
||||||
|
dataset_prompter = None
|
||||||
|
else:
|
||||||
|
# Apply dataset wrapper for regular datasets
|
||||||
|
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||||
|
dataset_config=dataset_config,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
cfg=cfg,
|
||||||
|
dataset_base_type=d_base_type,
|
||||||
|
dataset=dataset,
|
||||||
|
dataset_prompt_style=d_prompt_style,
|
||||||
|
processor=processor,
|
||||||
|
)
|
||||||
|
|
||||||
return dataset_wrapper, dataset_prompter
|
return dataset_wrapper, dataset_prompter
|
||||||
|
|
||||||
|
|||||||
@@ -524,7 +524,9 @@ def generate_dataset_hash_from_config(
|
|||||||
return str(md5(config_str))
|
return str(md5(config_str))
|
||||||
|
|
||||||
|
|
||||||
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
def merge_datasets(
|
||||||
|
datasets: list[Dataset | IterableDataset], cfg: DictDefault
|
||||||
|
) -> Dataset | IterableDataset:
|
||||||
"""Merge multiple datasets into one with optional shuffling.
|
"""Merge multiple datasets into one with optional shuffling.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -534,6 +536,41 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
|||||||
Returns:
|
Returns:
|
||||||
Merged dataset.
|
Merged dataset.
|
||||||
"""
|
"""
|
||||||
|
# Check if we're dealing with streaming datasets
|
||||||
|
if any(isinstance(ds, IterableDataset) for ds in datasets):
|
||||||
|
# All datasets must be streaming for merging
|
||||||
|
if not all(isinstance(ds, IterableDataset) for ds in datasets):
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot mix streaming and non-streaming datasets. "
|
||||||
|
"Either all datasets must be streaming or none."
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(datasets) == 1:
|
||||||
|
ds = datasets[0]
|
||||||
|
# Streaming datasets handle shuffling differently
|
||||||
|
if cfg.shuffle_merged_datasets and not cfg.curriculum_sampling:
|
||||||
|
return ds.shuffle(seed=cfg.seed, buffer_size=10_000)
|
||||||
|
return ds
|
||||||
|
|
||||||
|
# Merge streaming datasets
|
||||||
|
LOG.info("Merging streaming datasets...")
|
||||||
|
from datasets import interleave_datasets
|
||||||
|
|
||||||
|
# For streaming, we interleave datasets instead of concatenating
|
||||||
|
merged_dataset = interleave_datasets(datasets)
|
||||||
|
|
||||||
|
if cfg.shuffle_merged_datasets:
|
||||||
|
LOG.debug("Shuffling merged streaming datasets...")
|
||||||
|
if cfg.curriculum_sampling:
|
||||||
|
LOG.warning(
|
||||||
|
"Shuffling merged datasets with curriculum sampling is not recommended. "
|
||||||
|
"This will randomize the order of samples."
|
||||||
|
)
|
||||||
|
merged_dataset = merged_dataset.shuffle(seed=cfg.seed, buffer_size=10_000)
|
||||||
|
|
||||||
|
return merged_dataset
|
||||||
|
|
||||||
|
# Original logic for non-streaming datasets
|
||||||
if len(datasets) == 1:
|
if len(datasets) == 1:
|
||||||
ds = datasets[0]
|
ds = datasets[0]
|
||||||
|
|
||||||
|
|||||||
150
src/axolotl/utils/data/streaming.py
Normal file
150
src/axolotl/utils/data/streaming.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""Utilities for handling streaming datasets."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from datasets import Dataset, IterableDataset
|
||||||
|
from torch.utils.data import RandomSampler
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
from axolotl.utils.trainer import add_position_ids
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_streaming_sft_dataset(
|
||||||
|
dataset: IterableDataset,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
cfg,
|
||||||
|
dataset_config,
|
||||||
|
d_base_type: str,
|
||||||
|
d_prompt_style: str | None,
|
||||||
|
processor: Any | None,
|
||||||
|
max_tokens: int = 2048,
|
||||||
|
buffer_size: int = 10_000,
|
||||||
|
) -> IterableDataset:
|
||||||
|
"""
|
||||||
|
Wrap a streaming SFT dataset with tokenization and optional packing.
|
||||||
|
|
||||||
|
This is similar to wrap_pretraining_dataset but for SFT datasets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The streaming dataset to wrap
|
||||||
|
tokenizer: Tokenizer to use
|
||||||
|
cfg: Configuration object
|
||||||
|
dataset_config: Dataset configuration
|
||||||
|
d_base_type: Base dataset type
|
||||||
|
d_prompt_style: Prompt style
|
||||||
|
processor: Optional processor for multimodal
|
||||||
|
max_tokens: Maximum sequence length
|
||||||
|
buffer_size: Buffer size for shuffling
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wrapped streaming dataset ready for training
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from axolotl.utils.data.wrappers import get_dataset_wrapper
|
||||||
|
|
||||||
|
# Apply shuffling if configured
|
||||||
|
if cfg.shuffle_merged_datasets:
|
||||||
|
LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
|
||||||
|
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
|
||||||
|
|
||||||
|
# For streaming datasets, we need to get column names from the first sample
|
||||||
|
remove_columns = []
|
||||||
|
for first_row in dataset:
|
||||||
|
remove_columns = list(first_row.keys())
|
||||||
|
break
|
||||||
|
|
||||||
|
# Reset dataset after peeking
|
||||||
|
if cfg.shuffle_merged_datasets:
|
||||||
|
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
|
||||||
|
|
||||||
|
# Define the encoding function - always add position_ids for compatibility
|
||||||
|
if cfg.sample_packing:
|
||||||
|
# For sample packing, we need to handle position_ids
|
||||||
|
def encode_streaming_packed(examples: Dict[str, List]) -> Dict[str, List]:
|
||||||
|
"""Encode examples for streaming with sample packing."""
|
||||||
|
# Convert the batch dict to a temporary Dataset for processing
|
||||||
|
temp_dataset = Dataset.from_dict(examples)
|
||||||
|
|
||||||
|
# Apply the dataset wrapper to tokenize
|
||||||
|
wrapped_dataset, _ = get_dataset_wrapper(
|
||||||
|
dataset_config=dataset_config,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
cfg=cfg,
|
||||||
|
dataset_base_type=d_base_type,
|
||||||
|
dataset=temp_dataset,
|
||||||
|
dataset_prompt_style=d_prompt_style,
|
||||||
|
processor=processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to dict for processing
|
||||||
|
result = {}
|
||||||
|
if hasattr(wrapped_dataset, "to_dict"):
|
||||||
|
result = wrapped_dataset.to_dict()
|
||||||
|
else:
|
||||||
|
for key in wrapped_dataset.column_names:
|
||||||
|
result[key] = wrapped_dataset[key]
|
||||||
|
|
||||||
|
# Add position_ids using the existing function
|
||||||
|
result = add_position_ids(result)
|
||||||
|
|
||||||
|
# For multipack attention, we may need to drop attention_mask
|
||||||
|
if cfg.pretrain_multipack_attn and "attention_mask" in result:
|
||||||
|
del result["attention_mask"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
encode_fn = encode_streaming_packed
|
||||||
|
else:
|
||||||
|
# Regular encoding without packing - still add position_ids for compatibility
|
||||||
|
def encode_streaming(examples: Dict[str, List]) -> Dict[str, List]:
|
||||||
|
"""Encode examples for streaming."""
|
||||||
|
# Convert the batch dict to a temporary Dataset for processing
|
||||||
|
temp_dataset = Dataset.from_dict(examples)
|
||||||
|
|
||||||
|
# Apply the dataset wrapper to tokenize
|
||||||
|
wrapped_dataset, _ = get_dataset_wrapper(
|
||||||
|
dataset_config=dataset_config,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
cfg=cfg,
|
||||||
|
dataset_base_type=d_base_type,
|
||||||
|
dataset=temp_dataset,
|
||||||
|
dataset_prompt_style=d_prompt_style,
|
||||||
|
processor=processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to dict format
|
||||||
|
result = {}
|
||||||
|
if hasattr(wrapped_dataset, "to_dict"):
|
||||||
|
result = wrapped_dataset.to_dict()
|
||||||
|
else:
|
||||||
|
for key in wrapped_dataset.column_names:
|
||||||
|
result[key] = wrapped_dataset[key]
|
||||||
|
|
||||||
|
# Add position_ids even without packing for compatibility
|
||||||
|
result = add_position_ids(result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
encode_fn = encode_streaming
|
||||||
|
|
||||||
|
# Map the encoding function over the streaming dataset
|
||||||
|
dataset = dataset.map(
|
||||||
|
encode_fn,
|
||||||
|
batched=True,
|
||||||
|
batch_size=buffer_size,
|
||||||
|
remove_columns=remove_columns,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set format for PyTorch
|
||||||
|
dataset = dataset.with_format("torch")
|
||||||
|
|
||||||
|
return dataset
|
||||||
@@ -178,8 +178,8 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
|
|
||||||
|
|
||||||
def handle_long_seq_in_dataset(
|
def handle_long_seq_in_dataset(
|
||||||
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
dataset: Dataset | IterableDataset, sequence_len: int, cfg: DictDefault
|
||||||
) -> Dataset:
|
) -> Dataset | IterableDataset:
|
||||||
"""Remove sequences longer than configured maximum from dataset.
|
"""Remove sequences longer than configured maximum from dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -190,7 +190,14 @@ def handle_long_seq_in_dataset(
|
|||||||
Returns:
|
Returns:
|
||||||
Filtered dataset with long sequences removed.
|
Filtered dataset with long sequences removed.
|
||||||
"""
|
"""
|
||||||
if "input_ids" not in dataset.column_names:
|
# Streaming datasets don't support filtering the same way
|
||||||
|
if isinstance(dataset, IterableDataset):
|
||||||
|
LOG.info(
|
||||||
|
"Streaming dataset detected - long sequence filtering will be done on-the-fly"
|
||||||
|
)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
if not hasattr(dataset, "column_names") or "input_ids" not in dataset.column_names:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
||||||
"expected for reward modeling."
|
"expected for reward modeling."
|
||||||
|
|||||||
@@ -244,6 +244,12 @@ class AxolotlInputConfig(
|
|||||||
dataloader_num_workers: int | None = None
|
dataloader_num_workers: int | None = None
|
||||||
dataloader_prefetch_factor: int | None = None
|
dataloader_prefetch_factor: int | None = None
|
||||||
dataloader_drop_last: bool | None = None
|
dataloader_drop_last: bool | None = None
|
||||||
|
streaming: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Enable streaming mode for training datasets to reduce memory usage and enable training on datasets larger than memory"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
accelerator_config: dict[str, Any] | None = None
|
accelerator_config: dict[str, Any] | None = None
|
||||||
|
|
||||||
@@ -830,15 +836,10 @@ class AxolotlInputConfig(
|
|||||||
include_tokens_per_second: bool | None = Field(
|
include_tokens_per_second: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "bool of whether to report tokens per second at the end of training. This is not supported with pre-training datasets."
|
"description": "bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time."
|
||||||
},
|
|
||||||
)
|
|
||||||
include_tkps: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens."
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
neftune_noise_alpha: float | None = Field(
|
neftune_noise_alpha: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -59,21 +59,16 @@ class ModelInputConfig(BaseModel):
|
|||||||
processor_type: str | None = Field(
|
processor_type: str | None = Field(
|
||||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||||
)
|
)
|
||||||
tokenizer_save_jinja_files: bool | None = Field(
|
|
||||||
default=True, # match the default behavior from transformers
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Whether to save jinja files for tokenizer, transformers default is True"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
trust_remote_code: bool | None = Field(
|
trust_remote_code: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Trust remote code for untrusted source"},
|
json_schema_extra={"description": "Trust remote code for untrusted source"},
|
||||||
)
|
)
|
||||||
|
|
||||||
experimental_skip_move_to_device: bool | None = Field(
|
experimental_skip_move_to_device: bool | None = Field(
|
||||||
default=True,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Don't move the model to the device before sharding. Set to `false` to revert to legacy behavior."
|
"description": "Don't move the model to the device before sharding. "
|
||||||
|
"This is an experimental feature that may be included in the future as the default."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1074,6 +1074,24 @@ class PretrainingValidationMixin:
|
|||||||
data["accelerator_config"]["dispatch_batches"] = False
|
data["accelerator_config"]["dispatch_batches"] = False
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_pretraining_split_batches_accelerate(cls, data):
|
||||||
|
# alternatively set ACCELERATE_SPLIT_BATCHES=False
|
||||||
|
if data.get("streaming"):
|
||||||
|
accelerator_config = data.get("accelerator_config", {})
|
||||||
|
if not accelerator_config:
|
||||||
|
data["accelerator_config"] = {
|
||||||
|
"split_batches": False,
|
||||||
|
"dispatch_batches": False,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
if accelerator_config.get("split_batches") is None:
|
||||||
|
data["accelerator_config"]["split_batches"] = False
|
||||||
|
if accelerator_config.get("dispatch_batches") is None:
|
||||||
|
data["accelerator_config"]["dispatch_batches"] = False
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ModelCompatibilityValidationMixin:
|
class ModelCompatibilityValidationMixin:
|
||||||
"""Validation methods for specific model compatibility."""
|
"""Validation methods for specific model compatibility."""
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
"""
|
|
||||||
e2e test for saving the tokenizer
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from tests.e2e.utils import check_model_output_exists
|
|
||||||
|
|
||||||
|
|
||||||
def test_tokenizer_no_save_jinja_files(temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"tokenizer_type": "AutoTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.02,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"chat_template": "chatml",
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 5,
|
|
||||||
"save_first_step": False,
|
|
||||||
"fp16": False,
|
|
||||||
"tokenizer_save_jinja_files": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
|
||||||
|
|
||||||
with patch("axolotl.train.execute_training"):
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
with open(f"{temp_dir}/tokenizer_config.json", "r", encoding="utf-8") as f:
|
|
||||||
tokenizer_config = f.read()
|
|
||||||
assert "chat_template" in tokenizer_config
|
|
||||||
Reference in New Issue
Block a user