Compare commits

...

14 Commits

Author SHA1 Message Date
Wing Lian
1447beb132 make sure to validate the config before normalizing so defaults get set (#2554)
* make sure to validate the config before normalizing so defaults get set

* validation not needed for particular test

* remove duplicate validations

* set qlora correctly
2025-04-24 13:01:43 -04:00
Dan Saunders
66f41ec6f1 disable codecov pr annotations (#2556) 2025-04-24 08:51:51 -04:00
NanoCode012
85053f4bd4 Fix(doc): add delinearize instruction (#2545)
* fix: mention to install pytorch before axolotl

* feat(doc): include instruction to delinearize

* fix: update instruction for delinearize with adapter
2025-04-24 01:03:43 -04:00
Wing Lian
a4d5112ae1 builds for torch 2.7.0 (#2552)
* builds for torch==2.7.0

* use xformers==0.0.29.post3

* no vllm support with torch 2.7

* update default, fix conditional

* no xformers for 270

* no vllm on 2.7.0 for multigpu test too

* remove deprecated verbose arg from scheduler

* 2.7.0 tests on cpu
2025-04-24 00:39:31 -04:00
Wing Lian
0d691cc2a7 add base docker image with pytorch 2.7.0 and variant for cuda 12.8 (#2551)
* add base docker image with pytorch 2.7.0 and variant for cuda 12.8

* my bash is terrible
2025-04-23 14:59:03 -04:00
Dan Saunders
c4053481ff Codecov fixes / improvements (#2549)
* adding codecov reporting

* random change

* codecov fixes

* adding missing dependency

* fix

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-04-23 10:33:30 -04:00
NanoCode012
a6d28d19b1 feat: add glm and glm4 multipack and cce (#2546)
* feat: add glm and glm4 multipack

* feat: add glm4 example

* feat: add cce for glm
2025-04-23 10:27:51 -04:00
Wing Lian
32e335dd51 fix missing host/port for vllm (#2543)
* fix missing host/port for vllm

* set tensor parallel size so it doesn't always default to cli override
2025-04-22 10:16:48 -04:00
Wing Lian
7651550850 make sure to download fixtures for kd test (#2541)
* make sure to download fixtures for kd test

* use same alpaca dataset
2025-04-21 10:31:50 -04:00
Wing Lian
341e95aac9 prevent rate limiting to hf when using dispatch batches (#2536) [skip ci] 2025-04-21 10:31:35 -04:00
Catgat
b882dfb63f Fixed Rex Scheduler Warm Up (#2535) [skip ci]
* Fixed Rex Scheduler Warm Up

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-04-21 10:30:55 -04:00
Wing Lian
b640db1dbc don't run multigpu tests twice, run SP in separate test (#2542)
* don't run multigpu tests twice, run SP in separate test

* fix multiline
2025-04-21 10:24:13 -04:00
Chiwan Park
4ce469d32e fix: upgrade liger to 0.5.8 and use native Gemma3 patches (#2527)
* fix: upgrade liger to 0.5.8 and use native Gemma3 patches

* fix: make lint happy

* doc: update Liger Kernel FLCE support for Gemma 3
2025-04-18 09:57:40 -07:00
Wing Lian
60a8f0958d zero val fix for beta (#2538) 2025-04-17 17:27:19 -07:00
57 changed files with 362 additions and 101 deletions

View File

@@ -46,6 +46,18 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""

View File

@@ -31,6 +31,11 @@ jobs:
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: vllm axolotl_extras: vllm
is_latest: true is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -93,6 +98,11 @@ jobs:
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -138,7 +148,7 @@ jobs:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.4.1 pytorch: 2.6.0
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:

View File

@@ -45,6 +45,13 @@ jobs:
axolotl_extras: vllm axolotl_extras: vllm
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]
timeout-minutes: 120 timeout-minutes: 120
steps: steps:
@@ -67,6 +74,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.multigpu modal run cicd.multigpu

View File

@@ -147,6 +147,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.e2e_tests modal run cicd.e2e_tests

View File

@@ -49,7 +49,7 @@ jobs:
max-parallel: 2 max-parallel: 2
matrix: matrix:
python_version: ["3.11"] python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"] pytorch_version: ["2.4.1", "2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -109,6 +109,7 @@ jobs:
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
uses: codecov/codecov-action@v5 uses: codecov/codecov-action@v5
with: with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }} flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false fail_ci_if_error: false
@@ -241,6 +242,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.e2e_tests modal run cicd.e2e_tests
@@ -268,6 +270,12 @@ jobs:
pytorch: 2.5.1 pytorch: 2.5.1
num_gpus: 1 num_gpus: 1
axolotl_extras: vllm axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -288,6 +296,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.e2e_tests modal run cicd.e2e_tests

View File

@@ -9,8 +9,7 @@ pytest -v --durations=10 -n8 \
--ignore=tests/patched/ \ --ignore=tests/patched/ \
--ignore=tests/cli \ --ignore=tests/cli \
/workspace/axolotl/tests/ \ /workspace/axolotl/tests/ \
--cov=axolotl \ --cov=axolotl
--cov-report=xml:coverage.xml
# Run lora kernels tests with coverage append # Run lora kernels tests with coverage append
pytest -v --durations=10 \ pytest -v --durations=10 \
@@ -51,11 +50,6 @@ pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/ \ /workspace/axolotl/tests/e2e/ \
--cov=axolotl \ --cov=axolotl \
--cov-append \ --cov-append \
--cov-report=xml:coverage.xml --cov-report=xml:e2e-coverage.xml
# Upload coverage to Codecov codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
if [ -f e2e-coverage.xml ]; then
codecov -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi

View File

@@ -28,6 +28,7 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""), "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub", "HF_HOME": "/workspace/data/huggingface-cache/hub",
} }

View File

@@ -29,6 +29,7 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"), "CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub", "HF_HOME": "/workspace/data/huggingface-cache/hub",
} }

View File

@@ -1,25 +1,23 @@
#!/bin/bash #!/bin/bash
set -e set -e
# only run one test at a time so as not to OOM the GPU
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
# Only run two tests at a time to avoid OOM on GPU (with coverage collection) # Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \ pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \ /workspace/axolotl/tests/e2e/multigpu/ \
--cov=axolotl \ --cov=axolotl
--cov-report=xml:multigpu-coverage.xml
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \ # Run solo tests with coverage append
pytest -v --durations=10 -n1 \
/workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov=axolotl \
--cov-append
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov=axolotl \ --cov=axolotl \
--cov-append \ --cov-append \
--cov-report=xml:multigpu-coverage.xml --cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov # Upload coverage to Codecov
if [ -f multigpu-coverage.xml ]; then codecov upload-process -t $CODECOV_TOKEN -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi

View File

@@ -49,3 +49,6 @@ comment:
require_changes: no require_changes: no
require_base: no require_base: no
require_head: yes require_head: yes
github_checks:
annotations: false

View File

@@ -37,3 +37,7 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \ pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working # The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -199,6 +199,17 @@ output_dir: # Directory to save evaluation results
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details. See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4
Delinearizes a Llama 4 linearized model into a regular HuggingFace Llama 4 model. This only works with the non-quantized linearized model.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```
This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
## Legacy CLI Usage ## Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI: While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:

View File

@@ -19,6 +19,12 @@ This guide covers all the ways you can install and set up Axolotl for your envir
## Installation Methods {#sec-installation-methods} ## Installation Methods {#sec-installation-methods}
::: {.callout-important}
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
### PyPI Installation (Recommended) {#sec-pypi} ### PyPI Installation (Recommended) {#sec-pypi}
```{.bash} ```{.bash}

View File

@@ -0,0 +1,62 @@
base_model: THUDM/GLM-4-32B-0414
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_4bit: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -26,3 +26,11 @@ Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k contenxt length @
### Llama 4 Maverick 17Bx128Experts (400B) ### Llama 4 Maverick 17Bx128Experts (400B)
Coming Soon Coming Soon
## Delinearized Llama 4 Models
We provide a script to delinearize Llama 4 linearized models into regular HuggingFace Llama 4 models.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```

View File

@@ -1,4 +1,5 @@
codecov codecov
codecov-cli
pytest pytest
pytest-cov pytest-cov
pytest-retry pytest-retry

View File

@@ -6,7 +6,7 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
autoawq==0.2.7.post3 autoawq==0.2.7.post3
liger-kernel==0.5.6 liger-kernel==0.5.8
# END section # END section
packaging==23.2 packaging==23.2
@@ -19,6 +19,7 @@ datasets==3.5.0
deepspeed>=0.15.4 deepspeed>=0.15.4
trl==0.16.1 trl==0.16.1
hf_xet==1.0.0 hf_xet==1.0.0
hqq==0.2.5
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer

View File

@@ -51,7 +51,7 @@ def parse_requirements(extras_require_map):
try: try:
torch_version = version("torch") torch_version = version("torch")
except PackageNotFoundError: except PackageNotFoundError:
torch_version = "2.5.1" torch_version = "2.6.0" # default to torch 2.6
_install_requires.append(f"torch=={torch_version}") _install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version) version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -64,9 +64,15 @@ def parse_requirements(extras_require_map):
else: else:
raise ValueError("Invalid version format") raise ValueError("Invalid version format")
if (major, minor) >= (2, 6): if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2") # _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.3"] extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 5): elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))

View File

@@ -39,16 +39,16 @@ class TrainerCliArgs:
class VllmServeCliArgs: class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command.""" """Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: int = field( tensor_parallel_size: Optional[int] = field(
default=1, default=None,
metadata={"help": "Number of tensor parallel workers to use."}, metadata={"help": "Number of tensor parallel workers to use."},
) )
host: str = field( host: Optional[str] = field(
default="0.0.0.0", # nosec B104 default=None, # nosec B104
metadata={"help": "Host address to run the server on."}, metadata={"help": "Host address to run the server on."},
) )
port: int = field( port: Optional[int] = field(
default=8000, default=None,
metadata={"help": "Port to run the server on."}, metadata={"help": "Port to run the server on."},
) )
gpu_memory_utilization: Optional[float] = field( gpu_memory_utilization: Optional[float] = field(

View File

@@ -1040,9 +1040,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dataset_processes: if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta: if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta training_args_kwargs["beta"] = self.cfg.trl.beta
if self.cfg.orpo_alpha: elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ??? # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha training_args_kwargs["beta"] = self.cfg.orpo_alpha

View File

@@ -40,8 +40,8 @@ class GRPOStrategy:
if trl.use_vllm: if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port
if trl.vllm_server_timeout: if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_guided_decoding_regex: if trl.vllm_guided_decoding_regex:

View File

@@ -47,6 +47,8 @@ cut_cross_entropy: true
- qwen2 - qwen2
- cohere - cohere
- cohere2 - cohere2
- glm
- glm4
## Citation ## Citation

View File

@@ -0,0 +1,57 @@
"""GLM 4 patch. GLM family inherits from Llama."""
from types import MethodType
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)
def patch_glm(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm import modeling_glm
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm.GlmForCausalLM
), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm.GlmForCausalLM.forward = cce_forward
return None
def patch_glm4(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm4 import modeling_glm4
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm4.Glm4ForCausalLM
), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm4.Glm4ForCausalLM.forward = cce_forward
return None

View File

@@ -20,6 +20,10 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma3, patch_gemma3,
patch_gemma3_text, patch_gemma3_text,
) )
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
patch_glm,
patch_glm4,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import ( from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
patch_llama4, patch_llama4,
patch_llama4_text, patch_llama4_text,
@@ -45,6 +49,8 @@ CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"qwen2": patch_qwen2, "qwen2": patch_qwen2,
"cohere": patch_cohere, "cohere": patch_cohere,
"cohere2": patch_cohere2, "cohere2": patch_cohere2,
"glm": patch_glm,
"glm4": patch_glm4,
} }

View File

@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
- deepseek_v2 - deepseek_v2
- gemma - gemma
- gemma2 - gemma2
- gemma3 (partial support, no support for FLCE yet) - gemma3
- granite - granite
- jamba - jamba
- llama - llama

View File

@@ -21,7 +21,6 @@ It is designed to be performant, correct, and light-weight.
import inspect import inspect
import logging import logging
import sys import sys
from functools import partial
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
@@ -55,7 +54,6 @@ class LigerPlugin(BasePlugin):
) )
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -141,38 +139,6 @@ class LigerPlugin(BasePlugin):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy: if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4": elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import ( from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4, apply_liger_kernel_to_llama4,

View File

@@ -31,6 +31,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"starcoder2", "starcoder2",
"deepseek_v2", "deepseek_v2",
"deepseek_v3", "deepseek_v3",
"glm",
"glm4",
] ]

View File

@@ -272,7 +272,7 @@ class ReLoRAScheduler(LRScheduler):
self.warmup_steps = warmup_steps self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose) super().__init__(optimizer, inner_schedule.last_epoch)
def get_lr(self) -> float: def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch self.inner_schedule.last_epoch = self.last_epoch

View File

@@ -3,6 +3,7 @@
import functools import functools
import logging import logging
import os import os
import tempfile
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -117,9 +118,27 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
cfg.pretraining_dataset[0]["type"] or "pretrain", cfg.pretraining_dataset[0]["type"] or "pretrain",
) )
iter_ds = load_dataset( # when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
path, streaming=True, split=split, name=name, data_files=data_files # other ranks, we just need to present a fake dataset
) if (
cfg.accelerator_config
and cfg.accelerator_config.dispatch_batches
and not is_local_main_process()
):
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
f.write("text\n")
f.write("lorem ipsum dolor sit amet\n")
# rewind the file pointer to the beginning so we can read it again
f.seek(0)
iter_ds = load_dataset(
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip: if skip:
LOG.info(f"Skipping {skip} samples from the dataset") LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip) iter_ds = iter_ds.skip(skip)

View File

@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
self.max_lr = max_lr self.max_lr = max_lr
self.total_steps = total_steps self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps self.num_warmup_steps = num_warmup_steps
self.last_step = last_step - 1 self.last_step = max(last_step - 1, 0)
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming. # Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups: for group in optimizer.param_groups:

View File

@@ -660,6 +660,7 @@ class AxolotlInputConfig(
data.get("val_set_size") == 0 data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("eval_strategy")) and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets") and not data.get("test_datasets")
and data.get("eval_strategy") != "no"
): ):
raise ValueError( raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0" "eval_steps and eval_strategy are not supported with val_set_size == 0"

View File

@@ -36,3 +36,11 @@ class VllmConfig(BaseModel):
default=None, default=None,
json_schema_extra={"description": "Enable prefix caching for VLLM"}, json_schema_extra={"description": "Enable prefix caching for VLLM"},
) )
host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host for the vLLM server to start on"},
)
port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to start on"},
)

View File

@@ -193,6 +193,14 @@ def download_tiny_shakespeare_dataset():
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset") snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_evolkit_kd_sample_dataset():
# download the dataset
snapshot_download_w_retry(
"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture(): def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model") snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@@ -208,6 +216,16 @@ def download_huggyllama_model_fixture():
) )
@pytest.fixture(scope="session", autouse=True)
def download_llama33_70b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture(): def download_llama_1b_model_fixture():
# download the tokenizer only # download the tokenizer only
@@ -315,6 +333,14 @@ def download_llama2_model_fixture():
) )
@pytest.fixture(scope="session", autouse=True)
def download_llama32_1b_model_fixture():
snapshot_download_w_retry(
"osllmai-community/Llama-3.2-1B",
repo_type="model",
)
@pytest.fixture @pytest.fixture
@enable_hf_offline @enable_hf_offline
def tokenizer_huggyllama( def tokenizer_huggyllama(

View File

@@ -8,7 +8,7 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils import get_pytorch_version from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists from ..utils import check_model_output_exists
@@ -56,6 +56,7 @@ class TestCutCrossEntropyIntegration:
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
def test_llama_w_cce(self, min_cfg, temp_dir): def test_llama_w_cce(self, min_cfg, temp_dir):
cfg = DictDefault(min_cfg) cfg = DictDefault(min_cfg)
cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
@@ -101,6 +102,7 @@ class TestCutCrossEntropyIntegration:
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
@@ -129,6 +131,7 @@ class TestCutCrossEntropyIntegration:
attention_type: True, attention_type: True,
} }
) )
cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()

View File

@@ -5,7 +5,7 @@ Simple end-to-end test for Liger integration
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1 from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
@@ -54,6 +54,7 @@ class LigerIntegrationTestCase:
} }
) )
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
@@ -100,6 +101,7 @@ class LigerIntegrationTestCase:
} }
) )
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()

View File

View File

@@ -10,7 +10,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard from ...utils import check_tensorboard
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -0,0 +1,2 @@
# Tests under this directory should get run "solo" on their own as they
# seem to cause issues when run in the same batch as other tests.

View File

@@ -49,8 +49,9 @@ class TestPackedFlex:
}, },
"datasets": [ "datasets": [
{ {
"path": "vicgalle/alpaca-gpt4", "path": "tatsu-lab/alpaca",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -60,6 +60,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True, "fp16": True,
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,6 +105,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True, "fp16": True,
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -103,6 +104,7 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -67,6 +67,7 @@ class TestFusedLlama(unittest.TestCase):
cfg.bf16 = True cfg.bf16 = True
else: else:
cfg.fp16 = True cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -65,6 +65,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -105,6 +106,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_availab
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -70,6 +70,7 @@ class TestLoraLlama(unittest.TestCase):
else: else:
cfg.fp16 = True cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -120,6 +121,7 @@ class TestLoraLlama(unittest.TestCase):
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ class TestMistral(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,6 +105,7 @@ class TestMistral(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -60,6 +60,7 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -6,7 +6,7 @@ import unittest
import transformers import transformers
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
@@ -47,6 +47,7 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10, "eval_steps": 10,
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False) load_model(cfg, tokenizer, inference=False)
@@ -79,6 +80,7 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10, "eval_steps": 10,
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False) load_model(cfg, tokenizer, inference=False)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ class TestPhiMultipack(unittest.TestCase):
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -82,7 +83,7 @@ class TestPhiMultipack(unittest.TestCase):
"sample_packing": True, "sample_packing": True,
"flash_attention": True, "flash_attention": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"load_in_8bit": False, "load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
"lora_r": 64, "lora_r": 64,
"lora_alpha": 32, "lora_alpha": 32,
@@ -114,6 +115,7 @@ class TestPhiMultipack(unittest.TestCase):
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir from ..utils import check_model_output_exists, most_recent_subdir
@@ -46,8 +46,9 @@ class TestResumeLlama:
}, },
"datasets": [ "datasets": [
{ {
"path": "vicgalle/alpaca-gpt4", "path": "tatsu-lab/alpaca",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 2, "num_epochs": 2,
@@ -67,6 +68,7 @@ class TestResumeLlama:
cfg.bf16 = True cfg.bf16 = True
else: else:
cfg.fp16 = True cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard from ..utils import check_model_output_exists, check_tensorboard
@@ -72,6 +72,7 @@ class TestUnslothQLoRA:
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -122,6 +123,7 @@ class TestUnslothQLoRA:
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -177,6 +179,7 @@ class TestUnslothQLoRA:
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -41,8 +41,9 @@ class TestPackedFlex(unittest.TestCase):
}, },
"datasets": [ "datasets": [
{ {
"path": "vicgalle/alpaca-gpt4", "path": "tatsu-lab/alpaca",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,

View File

@@ -102,6 +102,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"use_tensorboard": True, "use_tensorboard": True,
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -109,6 +109,7 @@ class TestLlamaVision(unittest.TestCase):
"bf16": True, "bf16": True,
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -40,8 +40,9 @@ class TestPackedLlama(unittest.TestCase):
}, },
"datasets": [ "datasets": [
{ {
"path": "vicgalle/alpaca-gpt4", "path": "tatsu-lab/alpaca",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,

View File

@@ -79,7 +79,7 @@ class TestPhi(unittest.TestCase):
"tokenizer_type": "AutoTokenizer", "tokenizer_type": "AutoTokenizer",
"sequence_len": 2048, "sequence_len": 2048,
"sample_packing": False, "sample_packing": False,
"load_in_8bit": False, "load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
"lora_r": 64, "lora_r": 64,
"lora_alpha": 32, "lora_alpha": 32,
@@ -111,6 +111,7 @@ class TestPhi(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -57,6 +57,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
"seed": 42, "seed": 42,
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest import pytest
from datasets import Dataset from datasets import Dataset
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets
@@ -319,6 +319,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
"num_epochs": 1, "num_epochs": 1,
} }
) )
self.cfg_1 = validate_config(self.cfg_1)
normalize_config(self.cfg_1) normalize_config(self.cfg_1)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits") @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")