Compare commits

13 Commits: merged-255 ... preprocess

| Author | SHA1 | Date |
|---|---|---|
| | f8e92407ff | |
| | c12906134d | |
| | 8154d26614 | |
| | fefcbc300d | |
| | 7d479348ee | |
| | ce0259db13 | |
| | 2798817cf9 | |
| | 0e1b081e49 | |
| | 8df37ad91f | |
| | 9b74298328 | |
| | ae8738aa87 | |
| | ec52561a0c | |
| | eadb16c709 | |
.github/workflows/base.yml: 12 changed lines (vendored)

@@ -46,18 +46,6 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

.github/workflows/main.yml: 12 changed lines (vendored)

@@ -31,11 +31,6 @@ jobs:
pytorch: 2.6.0
axolotl_extras: vllm
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -98,11 +93,6 @@ jobs:
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -148,7 +138,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.4.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

.github/workflows/multi-gpu-e2e.yml: 8 changed lines (vendored)

@@ -45,13 +45,6 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -74,7 +67,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu

.github/workflows/tests-nightly.yml: 1 changed line (vendored)

@@ -147,7 +147,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests

.github/workflows/tests.yml: 11 changed lines (vendored)

@@ -49,7 +49,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0", "2.7.0"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20

steps:
@@ -109,7 +109,6 @@ jobs:
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
@@ -242,7+241,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
@@ -270,12 +268,6 @@ jobs:
pytorch: 2.5.1
num_gpus: 1
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -296,7 +288,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
cicd/cicd.sh: 12 changed lines

@@ -9,7 +9,8 @@ pytest -v --durations=10 -n8 \
--ignore=tests/patched/ \
--ignore=tests/cli \
/workspace/axolotl/tests/ \
--cov=axolotl
--cov=axolotl \
--cov-report=xml:coverage.xml

# Run lora kernels tests with coverage append
pytest -v --durations=10 \
@@ -50,6 +51,11 @@ pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:e2e-coverage.xml
--cov-report=xml:coverage.xml

codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
# Upload coverage to Codecov
if [ -f e2e-coverage.xml ]; then
codecov -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi

@@ -28,7 +28,6 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}

@@ -29,7 +29,6 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}

@@ -1,23 +1,25 @@
#!/bin/bash
set -e

# only run one test at a time so as not to OOM the GPU
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/

# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
/workspace/axolotl/tests/e2e/multigpu/ \
--cov=axolotl

# Run solo tests with coverage append
pytest -v --durations=10 -n1 \
/workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov=axolotl \
--cov-append
--cov-report=xml:multigpu-coverage.xml

pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:multigpu-coverage.xml

# Upload coverage to Codecov
codecov upload-process -t $CODECOV_TOKEN -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
if [ -f multigpu-coverage.xml ]; then
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi

@@ -49,6 +49,3 @@ comment:
require_changes: no
require_base: no
require_head: yes

github_checks:
annotations: false

@@ -37,7 +37,3 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10

RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi
docs/cli.qmd: 11 changed lines

@@ -199,17 +199,6 @@ output_dir: # Directory to save evaluation results

See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.

### delinearize-llama4

Delinearizes a Llama 4 linearized model into a regular HuggingFace Llama 4 model. This only works with the non-quantized linearized model.

```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```

This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.

## Legacy CLI Usage

While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:

@@ -19,12 +19,6 @@ This guide covers all the ways you can install and set up Axolotl for your environment.

## Installation Methods {#sec-installation-methods}

::: {.callout-important}
Please make sure to have Pytorch installed before installing Axolotl in your local environment.

Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::

### PyPI Installation (Recommended) {#sec-pypi}

```{.bash}

@@ -1,62 +0,0 @@
base_model: THUDM/GLM-4-32B-0414
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_4bit: true

datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

@@ -26,11 +26,3 @@ Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k contenxt length @
### Llama 4 Maverick 17Bx128Experts (400B)

Coming Soon

## Delinearized Llama 4 Models

We provide a script to delinearize Llama 4 linearized models into regular HuggingFace Llama 4 models.

```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```

@@ -1,5 +1,4 @@
codecov
codecov-cli
pytest
pytest-cov
pytest-retry

@@ -6,7 +6,7 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.8
liger-kernel==0.5.6
# END section

packaging==23.2
@@ -19,7 +19,6 @@ datasets==3.5.0
deepspeed>=0.15.4
trl==0.16.1
hf_xet==1.0.0
hqq==0.2.5

optimum==1.16.2
hf_transfer
setup.py: 12 changed lines

@@ -51,7 +51,7 @@ def parse_requirements(extras_require_map):
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.6.0" # default to torch 2.6
torch_version = "2.5.1"
_install_requires.append(f"torch=={torch_version}")

version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -64,15 +64,9 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")

if (major, minor) >= (2, 7):
if (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
_install_requires.append("xformers==0.0.29.post2")
extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
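A minimal, standalone sketch of the version-gating pattern that the setup.py hunks above modify. The fallback version string and the branch messages here are illustrative assumptions, not the project's actual pins:

```python
# Illustrative only: branch on the installed torch (major, minor) version,
# mirroring the parse_requirements() logic shown in the setup.py hunks.
import re
from importlib.metadata import PackageNotFoundError, version

try:
    torch_version = version("torch")
except PackageNotFoundError:
    torch_version = "2.6.0"  # assumed fallback when torch is not installed yet

match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if match is None:
    raise ValueError("Invalid version format")
major, minor = int(match.group(1)), int(match.group(2))

if (major, minor) >= (2, 7):
    print("torch >= 2.7: drop the pinned xformers requirement")
elif (major, minor) >= (2, 6):
    print("torch 2.6: swap in an xformers build known to work with vllm")
elif (major, minor) >= (2, 5):
    print("torch 2.5: keep the older xformers pin")
else:
    print("older torch: leave the requirements untouched")
```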
@@ -39,16 +39,16 @@ class TrainerCliArgs:
class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""

tensor_parallel_size: Optional[int] = field(
default=None,
tensor_parallel_size: int = field(
default=1,
metadata={"help": "Number of tensor parallel workers to use."},
)
host: Optional[str] = field(
default=None, # nosec B104
host: str = field(
default="0.0.0.0", # nosec B104
metadata={"help": "Host address to run the server on."},
)
port: Optional[int] = field(
default=None,
port: int = field(
default=8000,
metadata={"help": "Port to run the server on."},
)
gpu_memory_utilization: Optional[float] = field(

@@ -129,17 +129,19 @@ def load_preference_datasets(
total_num_steps = None

if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
if not cfg.rl == "grpo":
LOG.info("check_dataset_labels...")

tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)

check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)

return TrainDatasetMeta(
train_dataset=train_dataset,

@@ -1040,11 +1040,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes

if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha

@@ -40,8 +40,8 @@ class GRPOStrategy:

if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_guided_decoding_regex:

@@ -47,8 +47,6 @@ cut_cross_entropy: true
- qwen2
- cohere
- cohere2
- glm
- glm4

## Citation

@@ -1,57 +0,0 @@
"""GLM 4 patch. GLM family inherits from Llama."""

from types import MethodType

import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)


def patch_glm(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:

# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch

llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access

from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm import modeling_glm

if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm.GlmForCausalLM
), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model

modeling_glm.GlmForCausalLM.forward = cce_forward
return None


def patch_glm4(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:

# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch

llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access

from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm4 import modeling_glm4

if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm4.Glm4ForCausalLM
), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model

modeling_glm4.Glm4ForCausalLM.forward = cce_forward
return None

@@ -20,10 +20,6 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma3,
patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
patch_glm,
patch_glm4,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
patch_llama4,
patch_llama4_text,
@@ -49,8 +45,6 @@ CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"qwen2": patch_qwen2,
"cohere": patch_cohere,
"cohere2": patch_cohere2,
"glm": patch_glm,
"glm4": patch_glm4,
}

@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
- deepseek_v2
- gemma
- gemma2
- gemma3
- gemma3 (partial support, no support for FLCE yet)
- granite
- jamba
- llama
@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
import inspect
import logging
import sys
from functools import partial

from axolotl.integrations.base import BasePlugin

@@ -54,6 +55,7 @@ class LigerPlugin(BasePlugin):
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -139,6 +141,38 @@ class LigerPlugin(BasePlugin):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3

if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:

def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)

modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm

if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy

if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,
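The `_liger_rms_norm_wrapper` / `functools.partial` lines above are a small keyword-adapter pattern: the model constructs its norm layers as `Norm(dim, ...)`, while the replacement class expects `hidden_size=...`. A self-contained sketch of the same idea, using made-up class names rather than liger_kernel's API:

```python
# Assumed names for illustration; only the adapter pattern matches the hunk above.
from functools import partial

class ReplacementRMSNorm:
    def __init__(self, hidden_size, eps=1e-6, offset=0.0):
        self.hidden_size = hidden_size
        self.eps = eps
        self.offset = offset

def _rms_norm_wrapper(dim, **kwargs):
    # Convert the caller's positional `dim` into the `hidden_size` keyword.
    return ReplacementRMSNorm(hidden_size=dim, **kwargs)

# Bind the extra keyword arguments once; call sites keep passing (dim, eps=...).
PatchedNorm = partial(_rms_norm_wrapper, offset=1.0)
norm = PatchedNorm(2048, eps=1e-5)
print(norm.hidden_size, norm.eps, norm.offset)  # 2048 1e-05 1.0
```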
@@ -31,8 +31,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"starcoder2",
"deepseek_v2",
"deepseek_v3",
"glm",
"glm4",
]

@@ -272,7 +272,7 @@ class ReLoRAScheduler(LRScheduler):
self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch)
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)

def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch
@@ -4,30 +4,73 @@ module for base dataset transform strategies

import importlib
import logging
import sys

LOG = logging.getLogger("axolotl")


def import_from_path(module_name: str, file_path: str):
"""
Import a module from a file path.

Args:
module_name: Name of the module.
file_path: Path to the file.

Returns:
module: The imported module.
"""
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec is None:
raise ImportError(f"Could not create module spec for: {file_path}")
module = importlib.util.module_from_spec(spec)

sys.modules[module_name] = module
loader = importlib.machinery.SourceFileLoader(module_name, file_path)
spec.loader = loader
loader.exec_module(module)
return module


def load(strategy, cfg, module_base=None, **kwargs):
try:
if len(strategy.split(".")) == 1:
strategy = strategy + ".default"
load_fn = strategy.split(".")[-1]
if len(strategy.split(".")) > 1:
try:
importlib.import_module(
strategy.split(".")[-2],
".".join(strategy.split(".")[:-2]),
)
module_base = ".".join(strategy.split(".")[:-2])
strategy = strategy.split(".")[-2]
except ModuleNotFoundError:
strategy = "." + ".".join(strategy.split(".")[:-1])
else:
strategy = "." + ".".join(strategy.split(".")[:-1])
if len(strategy.split(".")) == 1:
strategy = strategy + ".default"
load_fn = strategy.split(".")[-1]
func = None
if len(strategy.split(".")) > 1:
try:
mod = importlib.import_module(
strategy.split(".")[-2],
".".join(strategy.split(".")[:-2]),
)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except ModuleNotFoundError:
pass

try:
mod = importlib.import_module(
"." + ".".join(strategy.split(".")[:-1]), module_base
)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except ModuleNotFoundError:
pass

try:
file_path = "/".join(strategy.split(".")[:-1]) + ".py"
module_name = strategy.split(".")[-2]
mod = import_from_path(module_name, file_path)
func = getattr(mod, load_fn)
if func is not None:
return func(cfg, **kwargs)
except FileNotFoundError:
pass
else:
strategy = "." + ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(strategy, module_base)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None

LOG.warning(f"unable to load strategy {strategy}")
return func
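The refactored `load()` above falls back from package imports to loading a plain `.py` file, which is what lets config values like `rewards.rand_reward_func` (see the GRPO preprocess test later in this compare) resolve to a function in a local `rewards.py`. A trimmed, standalone sketch of that last fallback; `resolve_from_file` is a name invented here for illustration:

```python
# Standalone sketch of the file-path fallback, not the project's full load() chain.
import importlib.util
import sys

def import_from_path(module_name: str, file_path: str):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create module spec for: {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module

def resolve_from_file(strategy: str):
    # "{file_name}.{fn_name}" -> function object loaded from "./{file_name}.py"
    module_name, fn_name = strategy.rsplit(".", 1)
    module = import_from_path(module_name, f"{module_name}.py")
    return getattr(module, fn_name)

# Usage, assuming a rewards.py defining rand_reward_func in the working directory:
# reward_fn = resolve_from_file("rewards.rand_reward_func")
```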
@@ -3,7 +3,6 @@
import functools
import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union

@@ -118,27 +117,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)

# when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
# other ranks, we just need to present a fake dataset
if (
cfg.accelerator_config
and cfg.accelerator_config.dispatch_batches
and not is_local_main_process()
):
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
f.write("text\n")
f.write("lorem ipsum dolor sit amet\n")
# rewind the file pointer to the beginning so we can read it again
f.seek(0)
iter_ds = load_dataset(
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)

iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)
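The comment in the hunk above ("we just need to present a fake dataset") describes a pattern that stands on its own: when batches are dispatched from the main process, non-main ranks only need some streaming dataset object, so a two-line CSV is enough. A sketch mirroring the `tempfile` plus `load_dataset("csv", ...)` calls shown in the diff:

```python
# Sketch of a placeholder streaming dataset for non-main ranks; the real code
# above additionally gates this on cfg.accelerator_config.dispatch_batches.
import tempfile

from datasets import load_dataset

with tempfile.NamedTemporaryFile(mode="w+", suffix=".csv", delete=False) as f:
    f.write("text\n")
    f.write("lorem ipsum dolor sit amet\n")
    placeholder_path = f.name

iter_ds = load_dataset("csv", data_files=placeholder_path, split="train", streaming=True)
print(next(iter(iter_ds)))  # {'text': 'lorem ipsum dolor sit amet'}
```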
@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
self.max_lr = max_lr
self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps
self.last_step = max(last_step - 1, 0)
self.last_step = last_step - 1

# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:
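The two `last_step` lines above differ only in clamping. A quick illustration of why that matters if a caller constructs the scheduler with `last_step=0` (an assumption made for the example, not a statement about the project's default):

```python
# last_step - 1 versus max(last_step - 1, 0) for a few starting values
for last_step in (0, 1, 5):
    print(last_step, last_step - 1, max(last_step - 1, 0))
# 0 -> -1 vs 0; 1 -> 0 vs 0; 5 -> 4 vs 4
```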
@@ -660,7 +660,6 @@ class AxolotlInputConfig(
data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets")
and data.get("eval_strategy") != "no"
):
raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0"

@@ -36,11 +36,3 @@ class VllmConfig(BaseModel):
default=None,
json_schema_extra={"description": "Enable prefix caching for VLLM"},
)
host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host for the vLLM server to start on"},
)
port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to start on"},
)
@@ -193,14 +193,6 @@ def download_tiny_shakespeare_dataset():
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")


@pytest.fixture(scope="session", autouse=True)
def download_evolkit_kd_sample_dataset():
# download the dataset
snapshot_download_w_retry(
"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
)


@pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@@ -216,16 +208,6 @@ def download_huggyllama_model_fixture():
)


@pytest.fixture(scope="session", autouse=True)
def download_llama33_70b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)


@pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture():
# download the tokenizer only
@@ -333,14 +315,6 @@ def download_llama2_model_fixture():
)


@pytest.fixture(scope="session", autouse=True)
def download_llama32_1b_model_fixture():
snapshot_download_w_retry(
"osllmai-community/Llama-3.2-1B",
repo_type="model",
)


@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama(
@@ -8,7 +8,7 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists
@@ -56,7 +56,6 @@ class TestCutCrossEntropyIntegration:
# pylint: disable=redefined-outer-name
def test_llama_w_cce(self, min_cfg, temp_dir):
cfg = DictDefault(min_cfg)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -102,7 +101,6 @@ class TestCutCrossEntropyIntegration:
"bf16": "auto",
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -131,7 +129,6 @@ class TestCutCrossEntropyIntegration:
attention_type: True,
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()

@@ -5,7 +5,7 @@ Simple end-to-end test for Liger integration
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault

from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
@@ -54,7 +54,6 @@ class LigerIntegrationTestCase:
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -101,7 +100,6 @@ class LigerIntegrationTestCase:
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -1,2 +0,0 @@
# Tests under this directory should get run "solo" on their own as they
# seem to cause issues when run in the same batch as other tests.

@@ -49,9 +49,8 @@ class TestPackedFlex:
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,

@@ -10,7 +10,7 @@ from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault

from ...utils import check_tensorboard
from ..utils import check_tensorboard

os.environ["WANDB_DISABLED"] = "true"
@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, with_temp_dir
@@ -60,7 +60,6 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -105,7 +104,6 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, with_temp_dir
@@ -63,7 +63,6 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,7 +103,6 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, with_temp_dir
@@ -67,7 +67,6 @@ class TestFusedLlama(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -11,7 +11,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, with_temp_dir
@@ -65,7 +65,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
}
)

cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -106,7 +105,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
}
)

cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -12,7 +12,7 @@ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, with_temp_dir
@@ -70,7 +70,6 @@ class TestLoraLlama(unittest.TestCase):
else:
cfg.fp16 = True

cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -121,7 +120,6 @@ class TestLoraLlama(unittest.TestCase):
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, with_temp_dir
@@ -63,7 +63,6 @@ class TestMistral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -105,7 +104,6 @@ class TestMistral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, with_temp_dir
@@ -60,7 +60,6 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -6,7 +6,7 @@ import unittest

import transformers

from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer

@@ -47,7 +47,6 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
@@ -80,7 +79,6 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, with_temp_dir
@@ -63,7 +63,6 @@ class TestPhiMultipack(unittest.TestCase):
}
)

cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -83,7 +82,7 @@ class TestPhiMultipack(unittest.TestCase):
"sample_packing": True,
"flash_attention": True,
"pad_to_sequence_len": True,
"load_in_4bit": True,
"load_in_8bit": False,
"adapter": "qlora",
"lora_r": 64,
"lora_alpha": 32,
@@ -115,7 +114,6 @@ class TestPhiMultipack(unittest.TestCase):
}
)

cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, most_recent_subdir
@@ -46,9 +46,8 @@ class TestResumeLlama:
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 2,
@@ -68,7 +67,6 @@ class TestResumeLlama:
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -10,7 +10,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, check_tensorboard
@@ -72,7 +72,6 @@ class TestUnslothQLoRA:
}
)

cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -123,7 +122,6 @@ class TestUnslothQLoRA:
}
)

cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -179,7 +177,6 @@ class TestUnslothQLoRA:
}
)

cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -41,9 +41,8 @@ class TestPackedFlex(unittest.TestCase):
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
tests/e2e/solo/test_preprocess.py: 85 changed lines (new file)

@@ -0,0 +1,85 @@
"""
E2E tests for preprocessing
"""

import logging
import os
import unittest

import transformers

from axolotl.cli.args import PreprocessCliArgs
from axolotl.common.datasets import load_preference_datasets
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestCustomRewardFunctionLoading(unittest.TestCase):
"""
Test case for GRPO training using single GPU
"""

def _utils_write_rewards(self):
# write cfg to yaml file
with open("rewards.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]

def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)

@with_temp_dir
def test_custom_rewards_fn_preprocess(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"strict": False,
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
"rewards.rand_reward_func"
], # format: '{file_name}.{fn_name}'
"reward_weights": [1.0],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": "rewards.oai_gsm8k_transform",
},
],
"dataset_prepared_path": temp_dir,
"gradient_accumulation_steps": 1,
"micro_batch_size": 1,
"learning_rate": 0.000005,
}
)

self._utils_write_rewards()

cfg = validate_config(cfg)
normalize_config(cfg)
parser = transformers.HfArgumentParser(PreprocessCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -102,7 +102,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -109,7 +109,6 @@ class TestLlamaVision(unittest.TestCase):
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -40,9 +40,8 @@ class TestPackedLlama(unittest.TestCase):
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,

@@ -79,7 +79,7 @@ class TestPhi(unittest.TestCase):
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
"sample_packing": False,
"load_in_4bit": True,
"load_in_8bit": False,
"adapter": "qlora",
"lora_r": 64,
"lora_alpha": 32,
@@ -111,7 +111,6 @@ class TestPhi(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -57,7 +57,6 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
"seed": 42,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest
from datasets import Dataset

from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
@@ -319,7 +319,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
"num_epochs": 1,
}
)
self.cfg_1 = validate_config(self.cfg_1)
normalize_config(self.cfg_1)

@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")