Compare commits

15 Commits (tensor-par...fp8)

| Author | SHA1 | Date |
|---|---|---|
|  | 8836986a92 |  |
|  | 105d0b350b |  |
|  | f544ab2bed |  |
|  | 641e6f7e51 |  |
|  | 6dc68a653f |  |
|  | 7de6a5639c |  |
|  | c74f045ba7 |  |
|  | 0402d19759 |  |
|  | b2430ce670 |  |
|  | 4c834bf25d |  |
|  | 8056ecd30e |  |
|  | 738a057674 |  |
|  | cdc71f73c8 |  |
|  | 6459ac7357 |  |
|  | 964d858da0 |  |
@@ -74,6 +74,7 @@ Features:
 | gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
 | XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
 | phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
+| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |

 ## Quickstart ⚡
@@ -96,6 +97,10 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
 # inference
 accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
     --lora_model_dir="./lora-out"
+
+# gradio
+accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
+    --lora_model_dir="./lora-out" --gradio
 ```

 ## Installation
@@ -918,6 +923,10 @@ Pass the appropriate flag to the train command:
 cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
   --base_model="./completed-model" --prompter=None --load_in_8bit=True
 ```

+-- With gradio hosting
+```bash
+python -m axolotl.cli.inference examples/your_config.yml --gradio
+```

 Please use `--sample_packing False` if you have it on and receive the error similar to below:
@@ -21,9 +21,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
+        pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
     else \
-        pip install -e .[flash-attn]; \
+        pip install -e .[deepspeed,flash-attn]; \
     fi

 # fix so that git fetch/pull from remote works
@@ -10,8 +10,10 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
 ARG PYTHON_VERSION="3.9"
 ARG PYTORCH_VERSION="2.0.1"
 ARG CUDA="118"
+ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"

 ENV PYTHON_VERSION=$PYTHON_VERSION
+ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

 RUN apt-get update \
     && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
@@ -27,47 +29,9 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 WORKDIR /workspace

 RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
-    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
+    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA

-FROM base-builder AS deepspeed-builder
-
-ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
-
-WORKDIR /workspace
-
-RUN git clone https://github.com/microsoft/DeepSpeed.git && \
-    cd DeepSpeed && \
-    MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
-
-FROM base-builder AS bnb-builder
-
-WORKDIR /workspace
-ARG CUDA="118"
-ENV CUDA=$CUDA
-ARG MAX_JOBS="-1"
-ENV MAX_JOBS=$MAX_JOBS
-
-RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
-    cd bitsandbytes && \
-    CUDA_VERSION=$CUDA make cuda11x && \
-    python setup.py bdist_wheel
-
-FROM base-builder
-
-ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
-ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
-
-RUN mkdir -p /workspace/builds
-COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
-
-RUN mkdir -p /workspace/wheels/bitsandbytes
-COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
-COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
-COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
-
-RUN pip3 install wheels/deepspeed-*.whl
-RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
-RUN git lfs install --skip-repo
-RUN pip3 install awscli && \
+RUN git lfs install --skip-repo && \
+    pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10
@@ -14,7 +14,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path: last_prepared_run
-val_set_size: 0.01
+val_set_size: 0.05

 adapter:
 lora_model_dir:

@@ -7,7 +7,7 @@ datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter: qlora
 lora_model_dir:
 sequence_len: 2048

@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./lora-out

 sequence_len: 4096

@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./qlora-out

 adapter: qlora

@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./lora-out

 sequence_len: 4096

@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./qlora-out

 adapter: qlora

@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./lora-out

 sequence_len: 4096

@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./qlora-out

 adapter: qlora

@@ -12,7 +12,7 @@ datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca:chat
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter: lora
 lora_model_dir:
 sequence_len: 2048

@@ -18,7 +18,7 @@ datasets:
       - Chain-of-Thought/formatted_cot_data/gsm8k_train.json
     type: "alpaca:chat"
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 # enable QLoRA
 adapter: qlora
 lora_model_dir:

@@ -12,7 +12,7 @@ datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca:chat
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter:
 lora_model_dir:
 sequence_len: 2048

@@ -7,7 +7,7 @@ datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter: qlora
 lora_model_dir:
 sequence_len: 2048

@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path: last_run_prepared
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./out

 sequence_len: 4096

@@ -15,7 +15,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter: lora
 lora_model_dir:
 sequence_len: 4096

@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./lora-out

 sequence_len: 4096

@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./qlora-out

 adapter: qlora

@@ -11,7 +11,7 @@ datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./relora-out

 adapter: qlora

@@ -12,7 +12,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./lora-out

 sequence_len: 4096

@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./out

 sequence_len: 8192

@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path: last_run_prepared
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./qlora-out

 adapter: qlora

@@ -9,7 +9,7 @@ datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter: qlora
 lora_model_dir:
 sequence_len: 1024

@@ -16,7 +16,7 @@ datasets:
       - openassistant_best_replies_train.jsonl
     type: "completion"
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 # enable QLoRA
 adapter: qlora
 lora_model_dir:
@@ -1 +0,0 @@
-# Page

@@ -1,4 +0,0 @@
-# Table of contents
-
-* [Page](README.md)
-* [Small dev details](small-dev-details.md)

@@ -1,3 +0,0 @@
-# Small dev details
-
-/
@@ -1,9 +1,9 @@
 --extra-index-url https://download.pytorch.org/whl/cu118
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 torch==2.0.1
-auto-gptq
+auto-gptq==0.4.2
 packaging
-peft @ git+https://github.com/huggingface/peft.git
+peft==0.6.0
 transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
 bitsandbytes>=0.41.1
 accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9

@@ -17,7 +17,7 @@ sentencepiece
 wandb
 einops
 xformers>=0.0.22
-optimum
+optimum==1.13.2
 hf_transfer
 colorama
 numba

@@ -31,4 +31,4 @@ scikit-learn==1.2.2
 pynvml
 art
 fschat==0.2.29
-tensor_parallel
+gradio
@@ -6,8 +6,10 @@ import os
 import random
 import sys
 from pathlib import Path
+from threading import Thread
 from typing import Any, Dict, List, Optional, Union

+import gradio as gr
 import torch
 import yaml

@@ -16,7 +18,7 @@ from accelerate.commands.config import config_args
 from art import text2art
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
-from transformers import GenerationConfig, TextStreamer
+from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer

 from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 from axolotl.logging_config import configure_logging
@@ -153,6 +155,91 @@ def do_inference(
     print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))


+def do_inference_gradio(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+):
+    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    prompter = cli_args.prompter
+    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+
+    for token, symbol in default_tokens.items():
+        # If the token isn't already specified in the config, add it
+        if not (cfg.special_tokens and token in cfg.special_tokens):
+            tokenizer.add_special_tokens({token: symbol})
+
+    prompter_module = None
+    if prompter:
+        prompter_module = getattr(
+            importlib.import_module("axolotl.prompters"), prompter
+        )
+
+    if cfg.landmark_attention:
+        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
+
+        set_model_mem_id(model, tokenizer)
+        model.set_mem_cache_args(
+            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
+        )
+
+    model = model.to(cfg.device)
+
+    def generate(instruction):
+        if not instruction:
+            return
+        if prompter_module:
+            # pylint: disable=stop-iteration-return
+            prompt: str = next(
+                prompter_module().build_prompt(instruction=instruction.strip("\n"))
+            )
+        else:
+            prompt = instruction.strip()
+        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+        model.eval()
+        with torch.no_grad():
+            generation_config = GenerationConfig(
+                repetition_penalty=1.1,
+                max_new_tokens=1024,
+                temperature=0.9,
+                top_p=0.95,
+                top_k=40,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                pad_token_id=tokenizer.pad_token_id,
+                do_sample=True,
+                use_cache=True,
+                return_dict_in_generate=True,
+                output_attentions=False,
+                output_hidden_states=False,
+                output_scores=False,
+            )
+            streamer = TextIteratorStreamer(tokenizer)
+            generation_kwargs = {
+                "inputs": batch["input_ids"].to(cfg.device),
+                "generation_config": generation_config,
+                "streamer": streamer,
+            }
+
+            thread = Thread(target=model.generate, kwargs=generation_kwargs)
+            thread.start()
+
+            all_text = ""
+
+            for new_text in streamer:
+                all_text += new_text
+                yield all_text
+
+    demo = gr.Interface(
+        fn=generate,
+        inputs="textbox",
+        outputs="text",
+        title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
+    )
+    demo.queue().launch(show_api=False, share=True)
+
+
 def choose_config(path: Path):
     yaml_files = list(path.glob("*.yml"))
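The new `do_inference_gradio` path above is built on the `TextIteratorStreamer` plus worker-thread pattern from transformers. A minimal, standalone sketch of that pattern (the `gpt2` checkpoint is only a placeholder, not something this PR uses):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer)

# generate() blocks until completion, so it runs on a worker thread while
# the main thread consumes decoded text chunks as they become available.
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 20},
)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```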
@@ -6,11 +6,16 @@ from pathlib import Path
 import fire
 import transformers

-from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
+from axolotl.cli import (
+    do_inference,
+    do_inference_gradio,
+    load_cfg,
+    print_axolotl_text_art,
+)
 from axolotl.common.cli import TrainerCliArgs


-def do_cli(config: Path = Path("examples/"), **kwargs):
+def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)

@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     )
     parsed_cli_args.inference = True

-    do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    if gradio:
+        do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    else:
+        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)


 if __name__ == "__main__":
@@ -6,7 +6,6 @@ import abc
 import importlib
 import logging
 import math
-import os
 import sys
 from abc import abstractmethod
 from dataclasses import dataclass, field

@@ -14,14 +13,13 @@ from functools import partial
 from pathlib import Path
 from typing import Optional, Union

-import tensor_parallel as tp
 import torch
 import transformers
 from datasets import Dataset
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
-from transformers.trainer_pt_utils import SequentialDistributedSampler
+from transformers.trainer_utils import seed_worker

 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (

@@ -32,9 +30,9 @@ from axolotl.utils.callbacks import (
     bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
-from axolotl.utils.collators import DataCollatorForSeq2Seq
+from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
-from axolotl.utils.distributed import is_distributed
+from axolotl.utils.samplers import MultipackBatchSampler
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

 try:

@@ -104,8 +102,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     bench_source_max_len: int = field(
         default=2048, metadata={"help": "Maximum source sequence length for bench."}
     )
-    tensor_parallel: bool = field(
-        default=False, metadata={"help": "Use tensor parallelism to train"}
+    dataloader_prefetch_factor: Optional[int] = field(
+        default=None,
+        metadata={"help": "prefetch_factor argument to the dataloader"},
     )
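The new `dataloader_prefetch_factor` argument is forwarded verbatim to PyTorch's `DataLoader`. A toy illustration of what it controls (dataset contents made up):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(1024))
# With num_workers=2 and prefetch_factor=4, up to 2 * 4 = 8 batches are
# loaded ahead of the training loop; prefetch_factor requires num_workers > 0.
loader = DataLoader(dataset, batch_size=8, num_workers=2, prefetch_factor=4)
```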
@@ -150,46 +149,69 @@ class AxolotlTrainer(Trainer):
         return self.lr_scheduler

     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.args.world_size > 1 and self.args.sample_packing:
-            return DistributedSampler(
-                self.train_dataset,
-                num_replicas=self.args.world_size,
-                rank=self.args.process_index,
-                seed=self.args.seed,
+        if self.args.sample_packing:
+            return MultipackBatchSampler(
+                RandomSampler(self.train_dataset),
+                self.args.train_batch_size,
+                drop_last=True,
+                batch_max_len=self._train_batch_size * self.args.max_seq_length,
+                lengths=(
+                    self.train_dataset.data.column("position_ids")
+                    .to_pandas()
+                    .apply(lambda x: x[-1] + 1)
+                    .values
+                ),
+                packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
         return super()._get_train_sampler()

     def _get_eval_sampler(
         self, eval_dataset: Dataset
     ) -> Optional[torch.utils.data.Sampler]:
-        if (
-            self.args.world_size > 1
-            and self.args.sample_packing
-            and self.args.eval_sample_packing is not False
-        ):
-            return SequentialDistributedSampler(
-                eval_dataset,
-                num_replicas=self.args.world_size,
-                rank=self.args.process_index,
-                batch_size=self.args.per_device_eval_batch_size,
+        if self.args.sample_packing and self.args.eval_sample_packing is not False:
+            return MultipackBatchSampler(
+                SequentialSampler(eval_dataset),
+                self.args.per_device_eval_batch_size,
+                drop_last=True,
+                batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
+                lengths=(
+                    eval_dataset.data.column("position_ids")
+                    .to_pandas()
+                    .apply(lambda x: x[-1] + 1)
+                    .values
+                ),
+                packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
         return super()._get_eval_sampler(eval_dataset)

-    def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
+    def get_train_dataloader(self) -> DataLoader:
         if self.args.sample_packing:
-            train_sampler = self._get_train_sampler()
-            return self.accelerator.prepare(
-                MultipackDistributedDataloader(
-                    self.train_dataset,
-                    batch_size=self._train_batch_size,
-                    seq_max_length=self.args.max_seq_length,
-                    collate_fn=self.data_collator,
-                    sampler=train_sampler,
-                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                    sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
-                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                    num_epochs=self.num_epochs,
-                )
+            train_dataset = self.train_dataset
+            train_dataset = train_dataset.remove_columns(["length"])
+            data_collator = self.data_collator
+            dataloader_params = {
+                "batch_size": self._train_batch_size,
+                "collate_fn": data_collator,
+                "num_workers": self.args.dataloader_num_workers,
+                "pin_memory": self.args.dataloader_pin_memory,
+            }
+            if self.args.dataloader_prefetch_factor:
+                dataloader_params[
+                    "prefetch_factor"
+                ] = self.args.dataloader_prefetch_factor
+
+            sampler = self._get_train_sampler()
+            if isinstance(sampler, BatchSampler):
+                dataloader_params["batch_sampler"] = sampler
+                del dataloader_params["batch_size"]
+            else:
+                dataloader_params["sampler"] = sampler
+                dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["worker_init_fn"] = seed_worker
+
+            self.accelerator.even_batches = False
+            return self.accelerator.prepare_data_loader(
+                DataLoader(train_dataset, **dataloader_params)
             )
         return super().get_train_dataloader()
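`MultipackBatchSampler` subclasses `torch.utils.data.BatchSampler`, which is why the code above moves it under the `batch_sampler` key and deletes `batch_size`: `DataLoader` treats `batch_sampler` as mutually exclusive with `batch_size`, `shuffle`, `sampler`, and `drop_last`. A toy sketch of that contract:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=True)

# Note: no batch_size here; the batch sampler already yields lists of indices.
loader = DataLoader(dataset, batch_sampler=batch_sampler)
for (batch,) in loader:
    print(batch)  # tensor([0, 1, 2, 3]) then tensor([4, 5, 6, 7])
```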
@@ -202,18 +224,29 @@ class AxolotlTrainer(Trainer):
             )

         eval_sampler = self._get_eval_sampler(eval_dataset)
-        return self.accelerator.prepare(
-            MultipackDistributedDataloader(
-                eval_dataset,
-                batch_size=self.args.eval_batch_size,
-                seq_max_length=self.args.max_seq_length,
-                collate_fn=self.data_collator,
-                sampler=eval_sampler,
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                sample_packing_seq_len_multiplier=self.args.eval_batch_size,
-                device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                num_epochs=self.num_epochs,
-            )
+        eval_dataset = eval_dataset.remove_columns(["length"])
+        data_collator = self.data_collator
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+        }
+        if self.args.dataloader_prefetch_factor:
+            dataloader_params[
+                "prefetch_factor"
+            ] = self.args.dataloader_prefetch_factor
+
+        if isinstance(eval_sampler, BatchSampler):
+            dataloader_params["batch_sampler"] = eval_sampler
+            del dataloader_params["batch_size"]
+        else:
+            dataloader_params["sampler"] = eval_sampler
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
+        self.accelerator.even_batches = False
+        return self.accelerator.prepare_data_loader(
+            DataLoader(eval_dataset, **dataloader_params)
         )
         return super().get_eval_dataloader(eval_dataset)
@@ -234,6 +267,8 @@ class AxolotlTrainer(Trainer):
             "num_workers": self.args.dataloader_num_workers,
             "pin_memory": self.args.dataloader_pin_memory,
         }
+        if self.args.dataloader_prefetch_factor:
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

         if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
             dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
@@ -251,14 +286,6 @@ class AxolotlTrainer(Trainer):
         # return (loss, outputs) if return_outputs else loss
         return super().compute_loss(model, inputs, return_outputs=return_outputs)

-    def _wrap_model(self, model, training=True, dataloader=None):
-        if self.args.tensor_parallel:
-            model = tp.tensor_parallel(model, distributed=is_distributed())
-            model.hf_device_map = tp.infer_sharded_device_map(model)
-        else:
-            model = super()._wrap_model(model, training=training, dataloader=dataloader)
-        return model
-

 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
@@ -384,10 +411,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         return trainer_kwargs, trainer_cls

     def hook_post_create_trainer(self, trainer):
-        if self.cfg.tensor_parallel:
-            trainer.model = trainer.accelerator.prepare_model(
-                trainer.model, device_placement=True
-            )
+        # TODO
         return trainer

     def get_callbacks(self):
@@ -459,6 +483,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["fp16"] = (
             self.cfg.fp16 and not self.cfg.bf16
         ) or False
+        if self.cfg.fp8:
+            training_arguments_kwargs["fp16"] = False
+            training_arguments_kwargs["bf16"] = False
+
         training_arguments_kwargs["tf32"] = self.cfg.tf32
         training_arguments_kwargs["warmup_steps"] = warmup_steps
         training_arguments_kwargs["logging_steps"] = logging_steps
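`TrainingArguments` has no FP8 switch here, so forcing `fp16` and `bf16` off defers the mixed-precision choice to accelerate. A rough sketch, assuming an FP8-capable GPU and an accelerate build with TransformerEngine support:

```python
from accelerate import Accelerator

# Equivalent to setting mixed_precision: fp8 in the accelerate config file.
accelerator = Accelerator(mixed_precision="fp8")
```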
@@ -509,6 +537,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "sample_packing_efficiency"
             ] = self.cfg.sample_packing_eff_est

+        if self.cfg.dataloader_pin_memory is not None:
+            training_arguments_kwargs[
+                "dataloader_pin_memory"
+            ] = self.cfg.dataloader_pin_memory
+        if self.cfg.dataloader_num_workers is not None:
+            training_arguments_kwargs[
+                "dataloader_num_workers"
+            ] = self.cfg.dataloader_num_workers
+        if self.cfg.dataloader_prefetch_factor is not None:
+            training_arguments_kwargs[
+                "dataloader_prefetch_factor"
+            ] = self.cfg.dataloader_prefetch_factor
+
         if self.cfg.eval_steps:
             training_arguments_kwargs["evaluation_strategy"] = "steps"
             training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
@@ -631,8 +672,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.micro_batch_size
         training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
-        training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True
-
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
@@ -690,7 +729,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=DataCollatorForSeq2Seq(
+            data_collator=BatchSamplerDataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
                 **data_collator_kwargs,
@@ -708,4 +747,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         for callback in self.get_post_trainer_create_callbacks(trainer):
             trainer.add_callback(callback)

+        if self.cfg.deepspeed and self.cfg.sample_packing:
+            trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
+                "train_micro_batch_size_per_gpu"
+            ] = self.cfg.micro_batch_size
+
         return trainer
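Presumably the override is needed because a packed row already bundles several samples, so the batch size DeepSpeed would otherwise infer no longer matches the config; pinning `train_micro_batch_size_per_gpu` keeps the two in agreement. Toy arithmetic:

```python
# Toy numbers only: with packing, one collated row of 4096 tokens may hold
# several samples, but DeepSpeed only needs the per-GPU row count, which
# stays equal to micro_batch_size.
micro_batch_size = 2          # rows per GPU per step (from the config)
avg_sample_len, seq_len = 700, 4096
samples_per_row = seq_len // avg_sample_len      # about 5 packed samples per row
effective_samples_per_step = micro_batch_size * samples_per_row
print(effective_samples_per_step)                # 10
```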
@@ -2,7 +2,7 @@

 import logging
 import os
-from typing import List
+from typing import List, Optional

 import torch
 from datasets import Dataset, IterableDataset

@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
         self,
         prompt_tokenizer: PromptTokenizingStrategy,
         dataset: IterableDataset,
+        process_count: Optional[int] = None,
         **kwargs,
     ):
         self.prompt_tokenizer = prompt_tokenizer
+        self.process_count = process_count
         super().__init__(self.process(dataset).data, **kwargs)

     def process(self, dataset):
         features = dataset.features.keys()
-        num_proc = min(64, os.cpu_count())
+        num_proc = (
+            min(64, self.process_count)
+            if self.process_count
+            else min(64, os.cpu_count())
+        )
         map_kwargs = {}
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
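`process_count` ultimately just caps the `num_proc` value handed to `datasets.Dataset.map`. For illustration, with made-up data:

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c", "d"]})
# A tokenization-style map fanned out over 2 worker processes.
ds = ds.map(lambda example: {"text": example["text"].upper()}, num_proc=2)
print(ds["text"])  # ['A', 'B', 'C', 'D']
```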
@@ -1,6 +1,5 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

-import logging
 import os
 import signal
 import sys

@@ -10,6 +9,7 @@ from typing import Optional

 import torch
 import transformers.modelcard
+from accelerate.logging import get_logger
 from datasets import Dataset
 from optimum.bettertransformer import BetterTransformer
 from transformers.deepspeed import is_deepspeed_zero3_enabled

@@ -26,7 +26,7 @@ src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)

 configure_logging()
-LOG = logging.getLogger("axolotl.train")
+LOG = get_logger("axolotl.train")


 @dataclass

@@ -44,7 +44,10 @@ def train(
     *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
 ):
     # load the tokenizer first
-    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
+    LOG.debug(
+        f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
+        main_process_only=True,
+    )
     tokenizer = load_tokenizer(cfg)

     train_dataset = dataset_meta.train_dataset

@@ -52,7 +55,10 @@ def train(
     total_num_steps = dataset_meta.total_num_steps

     # Load the model and tokenizer
-    LOG.info("loading model and (optionally) peft_config...")
+    msg = "loading model"
+    if cfg.adapter:
+        msg += " and peft_config..."
+    LOG.debug(msg)
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)

     safe_serialization = cfg.save_safetensors is True
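Switching from `logging.getLogger` to accelerate's `get_logger` is what makes the `main_process_only=True` keyword available on log calls. A minimal sketch:

```python
from accelerate.logging import get_logger

LOG = get_logger(__name__)
# Emitted once from the main process instead of once per rank.
LOG.debug("loading tokenizer...", main_process_only=True)
```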
@@ -1,13 +1,10 @@
 """Benchmarking and measurement utilities"""
 import functools
-import logging

 import pynvml
 import torch
 from pynvml.nvml import NVMLError

-LOG = logging.getLogger("axolotl.utils.bench")
-

 def check_cuda_device(default_value):
     """

@@ -65,14 +62,7 @@ def gpu_memory_usage_smi(device=0):


 def log_gpu_memory_usage(log, msg, device):
-    if not torch.cuda.is_available():
-        return (0, 0, 0)
-
-    try:
-        usage, cache, misc = gpu_memory_usage_all(device)
-    except ValueError as exc:
-        LOG.exception(exc)
-        return (0, 0, 0)
+    usage, cache, misc = gpu_memory_usage_all(device)
     extras = []
     if cache > 0:
         extras.append(f"+{cache:.03f}GB cache")
@@ -119,3 +119,30 @@ class DataCollatorForSeq2Seq:
             features["decoder_input_ids"] = decoder_input_ids

         return features
+
+
+@dataclass
+class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    """
+    Collator for multipack specific to the using the BatchSampler
+    """
+
+    def __call__(self, features, return_tensors=None):
+        chunked_data = {}
+        for feature in features[0].keys():
+            if feature == "length":
+                continue
+            if feature == "attention_mask":
+                arrays = [
+                    (1) * np.array(item[feature])
+                    for item in features
+                    if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+            else:
+                arrays = [
+                    np.array(item[feature]) for item in features if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+        features = [chunked_data]
+        return super().__call__(features, return_tensors=return_tensors)
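The collator's packing step is plain per-feature concatenation: each element delivered by the batch sampler is a list of samples, flattened into one long row. A toy example with made-up token ids:

```python
import numpy as np

features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]},
    {"input_ids": [4, 5], "attention_mask": [1, 1]},
]
# One "batch" element ends up holding several packed samples.
packed = {
    key: np.concatenate([np.array(item[key]) for item in features])
    for key in features[0]
}
print(packed["input_ids"])  # [1 2 3 4 5]
```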
@@ -70,7 +70,9 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False

-    if cfg.bf16 or cfg.bfloat16:
+    if cfg.fp8:
+        cfg.torch_dtype = torch.bfloat16
+    elif cfg.bf16 or cfg.bfloat16:
         cfg.torch_dtype = torch.bfloat16
     elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
         cfg.torch_dtype = torch.float16

@@ -369,10 +371,6 @@ def validate_config(cfg):
             "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
         )

-    if cfg.tensor_parallel and cfg.gradient_checkpointing:
-        raise ValueError(
-            "TensorParallelPreTrainedModel does not support gradient checkpointing"
-        )
-
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
@@ -80,11 +80,11 @@ def prepare_dataset(cfg, tokenizer):
         )
     if cfg.max_steps:
         total_num_steps = min(
-            calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
+            calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
         )
         LOG.info(f"Maximum number of steps set at {total_num_steps}")
     else:
-        total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+        total_num_steps = calculate_total_num_steps(cfg, train_dataset)
     return train_dataset, eval_dataset, total_num_steps, prompters
@@ -482,10 +482,14 @@ def get_dataset_wrapper(
             "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
         )
         dataset_prompter = UnsupportedPrompter()
-        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
     elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
         dataset_prompter = UnsupportedPrompter()
-        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
     elif d_base_type == "alpaca":
         dataset_prompter = AlpacaPrompter(d_prompt_style)
         ds_strategy = AlpacaPromptTokenizingStrategy(

@@ -494,7 +498,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "explainchoice":
         dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)

@@ -504,7 +510,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "concisechoice":
         dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)

@@ -514,7 +522,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "summarizetldr":
         dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)

@@ -524,7 +534,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "jeopardy":
         dataset_prompter = JeopardyPrompter(d_prompt_style)

@@ -534,7 +546,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "oasst":
         dataset_prompter = AlpacaPrompter(d_prompt_style)

@@ -544,7 +558,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "gpteacher":
         dataset_prompter = GPTeacherPrompter(d_prompt_style)

@@ -554,7 +570,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "reflection":
         dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)

@@ -564,7 +582,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     else:
         suffix = ""
@@ -50,6 +50,17 @@ def get_world_size():
     return int(os.getenv("WORLD_SIZE", "1"))


+@contextmanager
+def zero_only():
+    """
+    Context manager that only runs the enclosed block on the main rank.
+    """
+    if is_main_process():
+        yield
+    else:
+        yield None
+
+
 @contextmanager
 def zero_first(is_main):
     """
@@ -7,7 +7,6 @@ from typing import Optional, Tuple # noqa: F401
|
|||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
import transformers.utils.bitsandbytes
|
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from peft import PeftConfig, prepare_model_for_kbit_training
|
from peft import PeftConfig, prepare_model_for_kbit_training
|
||||||
from peft.tuners.lora import QuantLinear
|
from peft.tuners.lora import QuantLinear
|
||||||
@@ -222,7 +221,7 @@ def load_model(
|
|||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
llm_int8_threshold=6.0,
|
llm_int8_threshold=6.0,
|
||||||
llm_int8_has_fp16_weight=False,
|
llm_int8_has_fp16_weight=False,
|
||||||
bnb_4bit_compute_dtype=torch.float16,
|
bnb_4bit_compute_dtype=cfg.torch_dtype,
|
||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4",
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
@@ -236,12 +235,7 @@ def load_model(
|
|||||||
model_kwargs["use_flash_attention_2"] = True
|
model_kwargs["use_flash_attention_2"] = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if (
|
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||||
cfg.is_llama_derived_model
|
|
||||||
and not cfg.trust_remote_code
|
|
||||||
and not cfg.gptq
|
|
||||||
and not cfg.tensor_parallel
|
|
||||||
):
|
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
config_kwargs = {}
|
config_kwargs = {}
|
||||||
@@ -307,7 +301,7 @@ def load_model(
|
|||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel:
|
elif model_type and not cfg.trust_remote_code:
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -322,17 +316,6 @@ def load_model(
|
|||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif cfg.tensor_parallel:
|
|
||||||
model_kwargs.pop("device_map")
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
base_model,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
offload_state_dict=True,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -383,18 +366,15 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
embeddings_len = (
|
||||||
embeddings_len = (
|
math.ceil(len(tokenizer) / 32) * 32
|
||||||
math.ceil(len(tokenizer) / 32) * 32
|
if cfg.resize_token_embeddings_to_32x
|
||||||
if cfg.resize_token_embeddings_to_32x
|
else len(tokenizer)
|
||||||
else len(tokenizer)
|
)
|
||||||
)
|
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
||||||
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
model.resize_token_embeddings(embeddings_len)
|
||||||
model.resize_token_embeddings(embeddings_len)
|
else:
|
||||||
else:
|
model.tie_weights()
|
||||||
model.tie_weights()
|
|
||||||
except NotImplementedError:
|
|
||||||
LOG.warning("`resize_token_embeddings` not implemented on model")
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model.config, "max_position_embeddings")
|
hasattr(model.config, "max_position_embeddings")
|
||||||
@@ -462,14 +442,7 @@ def load_model(
|
|||||||
if cfg.ddp and not load_in_8bit:
|
if cfg.ddp and not load_in_8bit:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
|
|
||||||
if (
|
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||||
torch.cuda.device_count() > 1
|
|
||||||
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
|
||||||
and (cfg.load_in_4bit)
|
|
||||||
):
|
|
||||||
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
|
||||||
# so let's only set it for the 4bit, see
|
|
||||||
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
|
||||||
setattr(model, "is_parallelizable", True)
|
setattr(model, "is_parallelizable", True)
|
||||||
setattr(model, "model_parallel", True)
|
setattr(model, "model_parallel", True)
|
||||||
|
|
||||||
@@ -497,12 +470,7 @@ def load_adapter(model, cfg, adapter, inference=False):
|
|||||||
if adapter is None:
|
if adapter is None:
|
||||||
return model, None
|
return model, None
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
try:
|
model.enable_input_require_grads()
|
||||||
model.enable_input_require_grads()
|
|
||||||
except NotImplementedError:
|
|
||||||
LOG.warning("enable_input_require_grads not implemented on model")
|
|
||||||
if adapter == "qlora" and cfg.tensor_parallel:
|
|
||||||
model, _ = load_tp_qlora(model)
|
|
||||||
if adapter in ["lora", "qlora"]:
|
if adapter in ["lora", "qlora"]:
|
||||||
return load_lora(model, cfg, inference=inference)
|
return load_lora(model, cfg, inference=inference)
|
||||||
if adapter == "llama-adapter":
|
if adapter == "llama-adapter":
|
||||||
@@ -554,25 +522,6 @@ def find_all_linear_names(model):
|
|||||||
return list(lora_module_names)
|
return list(lora_module_names)
|
||||||
|
|
||||||
|
|
||||||
def load_tp_qlora(model):
|
|
||||||
from transformers.utils.bitsandbytes import replace_with_bnb_linear
|
|
||||||
|
|
||||||
model = replace_with_bnb_linear(
|
|
||||||
model,
|
|
||||||
quantization_config=BitsAndBytesConfig(
|
|
||||||
load_in_4bit=True,
|
|
||||||
llm_int8_threshold=6.0,
|
|
||||||
llm_int8_has_fp16_weight=False,
|
|
||||||
bnb_4bit_compute_dtype=torch.float16,
|
|
||||||
bnb_4bit_use_double_quant=True,
|
|
||||||
bnb_4bit_quant_type="nf4",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
model.is_loaded_in_4bit = True
|
|
||||||
|
|
||||||
return model, None
|
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model, cfg, inference=False):
|
def load_lora(model, cfg, inference=False):
|
||||||
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
|
|
||||||
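
The quantization hunk above stops hard-coding fp16 as the 4-bit compute dtype and follows the configured training dtype instead. A minimal sketch of the resulting config, assuming `cfg.torch_dtype` resolves to `torch.bfloat16` (the `torch_dtype` variable below is an illustrative stand-in, not axolotl's actual plumbing):

```python
import torch
from transformers import BitsAndBytesConfig

torch_dtype = torch.bfloat16  # stand-in for cfg.torch_dtype (assumption)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch_dtype,  # previously hard-coded to torch.float16
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
```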
src/axolotl/utils/samplers/__init__.py (new file, +4)
@@ -0,0 +1,4 @@
+"""
+axolotl samplers module
+"""
+from .multipack import MultipackBatchSampler  # noqa: F401
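
The package `__init__.py` re-exports the sampler, so downstream code (see the trainer changes below) can import it from the package root:

```python
from axolotl.utils.samplers import MultipackBatchSampler
```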
src/axolotl/utils/samplers/multipack.py (new file, +193)
@@ -0,0 +1,193 @@
+# pylint: skip-file
+"""
+Multipack Batch Sampler
+"""
+import logging
+import math
+import os
+from typing import Any, Iterable, List, Union
+
+import numba
+import numpy as np
+from torch.utils.data import BatchSampler, Sampler
+
+LOG = logging.getLogger("axolotl.utils.samplers.multipack")
+
+
+@numba.njit
+def ffd_check(a: np.ndarray, c: int, n: int):
+    # First-fit-decreasing bin packing
+    # Check if a[] could fit in n bins with capacity c
+    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
+
+    a = np.sort(a)[::-1]
+    bins = np.full((n,), c, dtype=a.dtype)
+    for size in a:
+        not_found = True
+        for idx in range(n):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                not_found = False
+                break
+
+        if not_found:
+            return False
+
+    return True
+
+
+@numba.njit
+def ffd_with_result(a: np.ndarray, c: int, start_index: int):
+    # First-fit-decreasing bin packing (with result return)
+
+    indices = np.argsort(a)[::-1]
+    a = a[indices]
+
+    bins: List[Any] = []
+    bins_result: List[Any] = []
+    for a_id, size in enumerate(a):
+        add_new = True
+        for idx in range(len(bins)):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                bins_result[idx].append(indices[a_id] + start_index)
+                add_new = False
+                break
+
+        if add_new:
+            bins.append(c - size)
+            bins_result.append([indices[a_id] + start_index])
+
+    return bins_result
+
+
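To see what the two FFD helpers do, here is a tiny illustrative run (not part of the diff; it assumes the functions above are importable from the new module). Five sequence lengths packed into bins of capacity 8:

```python
import numpy as np

lengths = np.array([7, 2, 5, 3, 4], dtype=np.int64)
ffd_check(lengths, 8, 3)        # True: e.g. bins [7], [5, 3], [4, 2] all fit capacity 8
ffd_with_result(lengths, 8, 0)  # [[0], [2, 3], [4, 1]] -- original indices grouped per bin
```

The remainder of the new file, the dynamic allocator and the sampler class, continues below.
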
+@numba.njit
+def allocate(
+    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
+):
+    # Dynamic batch allocator, similar to Multifit
+    # https://en.wikipedia.org/wiki/Multifit_algorithm
+    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
+
+    s = 0
+    start_index = 0
+    result = []
+
+    while True:
+        # binary search [l, r)
+        left = 1
+        right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
+
+        while right - left > 1:
+            mid = (left + right) // 2
+            if ffd_check(lengths[start_index : start_index + mid], c, n):
+                left = mid
+            else:
+                right = mid
+
+        # use length l
+        batch = ffd_with_result(
+            lengths[start_index : start_index + left], c, start_index
+        )
+        assert len(batch) <= n
+        if len(batch) < n:
+            break
+
+        start_index += left
+        s = lengths_cumsum[start_index - 1]
+
+        # add local rank
+        result.append(batch[rank])
+
+    return result, s, len(result) * c * n
+
+
+class MultipackBatchSampler(BatchSampler):
+    """
+    Batch Sampler class for multipack
+    """
+
+    def __init__(
+        self,
+        sampler: Union[Sampler[int], Iterable[int]],
+        batch_size: int,
+        drop_last: bool,
+        batch_max_len: int,
+        lengths: np.ndarray,
+        packing_efficiency_estimate: float = 1.0,
+    ):
+        super().__init__(sampler, batch_size, drop_last)
+        self.batch_size = None
+        self.batch_max_len = batch_max_len
+        self.lengths: np.ndarray = lengths
+        self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
+
+        assert isinstance(self.lengths, np.ndarray)
+
+        self.epoch = 0
+
+        # statistics
+        self.eff_total_used = 0
+        self.eff_total_slots = 0
+
+    def set_epoch(self, epoch: int):
+        self.epoch = epoch
+
+    def generate_batches(self, set_stats=False):
+        indices = [idx for idx in self.sampler]
+
+        lengths = self.lengths[indices]
+        lengths_cumsum = np.cumsum(lengths)
+
+        batches, total_used, total_slots = allocate(
+            lengths=lengths,
+            lengths_cumsum=lengths_cumsum,
+            rank=0,
+            c=self.batch_max_len,
+            n=1,
+        )
+
+        batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
+
+        # statistics
+        if set_stats:
+            self.eff_total_used += total_used
+            self.eff_total_slots += total_slots
+
+        return batches
+
+    def __iter__(self):
+        batches = self.generate_batches(set_stats=True)
+        return iter(batches)
+
+    def num_batches(self):
+        batches = self.generate_batches(set_stats=True)
+        return len(batches)
+
+    def efficiency(self):
+        return self.eff_total_used / self.eff_total_slots
+
+    def __len__(self):
+        self.num_batches()
+        return self._len_est()
+
+    def _len_est(self):
+        world_size = int(os.getenv("WORLD_SIZE", "1"))
+        lengths_sum = np.sum(self.lengths)
+        lengths_sum_per_device = lengths_sum // world_size
+        LOG.info(
+            f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
+            f"total_num_tokens per device: {lengths_sum_per_device}"
+        )
+
+        # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
+        return (
+            world_size
+            * math.floor(
+                0.99
+                * lengths_sum_per_device
+                / self.packing_efficiency_estimate
+                // self.batch_max_len
+            )
+            - 1
+        )
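
A sketch of how the sampler is meant to be driven; this mirrors the trainer change further down. `train_dataset`, `micro_batch_size`, `sequence_len`, and `sample_lengths` are illustrative stand-ins:

```python
import numpy as np
from torch.utils.data import DataLoader, RandomSampler

sampler = MultipackBatchSampler(
    sampler=RandomSampler(train_dataset),
    batch_size=micro_batch_size,
    drop_last=True,
    batch_max_len=micro_batch_size * sequence_len,  # token capacity of one packed batch
    lengths=np.array(sample_lengths),               # per-example token counts
)
data_loader = DataLoader(train_dataset, batch_sampler=sampler)
```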
@@ -1,5 +1,4 @@
 """Module containing the Trainer class and related functions"""
-import logging
 import math
 import os
 from contextlib import contextmanager
@@ -9,21 +8,15 @@ from typing import List
 import numpy as np
 import torch
 import torch.cuda
-import torch.distributed as dist
+from accelerate.logging import get_logger
 from datasets import set_caching_enabled
-from torch.utils.data import DistributedSampler, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler

 from axolotl.core.trainer_builder import HFCausalTrainerBuilder
-from axolotl.utils.collators import DataCollatorForSeq2Seq
-from axolotl.utils.dataloader import MultipackDistributedDataloader
-from axolotl.utils.distributed import (
-    is_distributed,
-    is_main_process,
-    reduce_and_broadcast,
-    zero_first,
-)
+from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
+from axolotl.utils.samplers import MultipackBatchSampler

-LOG = logging.getLogger("axolotl")
+LOG = get_logger("axolotl")


 @torch.jit.script
@@ -148,19 +141,18 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
    return train_dataset, eval_dataset


-def calculate_total_num_steps(cfg, train_dataset, tokenizer):
+def calculate_total_num_steps(cfg, train_dataset):
    if cfg.sample_packing:
        # we have to drop anything longer then sequence len otherwise
        # flash attention with position ids fails
        if not cfg.total_num_tokens:
-            LOG.info("calculating total_num_tokens")
            total_num_tokens = np.sum(
                train_dataset.data.column("input_ids")
                .to_pandas()
                .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
                .values
            )
-            LOG.info(f"total_num_tokens: {total_num_tokens}")
+            LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
            cfg.total_num_tokens = total_num_tokens

        if not cfg.total_supervised_tokens:
@@ -170,7 +162,10 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                .apply(lambda x: np.sum(np.array(x) != -100))
                .sum()
            )
-            LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
+            LOG.debug(
+                f"`total_supervised_tokens: {total_supervised_tokens}`",
+                main_process_only=True,
+            )
            cfg.total_supervised_tokens = total_supervised_tokens

        if cfg.sample_packing_eff_est:
@@ -189,41 +184,41 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                )
                * cfg.num_epochs
            )
-            LOG.info(
-                f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
+            LOG.debug(
+                f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
+                main_process_only=True,
            )
        else:
-            if cfg.world_size > 1 and is_distributed():
-                sampler = DistributedSampler(
-                    train_dataset,
-                    num_replicas=cfg.world_size,
-                    rank=dist.get_rank(),
-                    seed=cfg.seed or 42,
-                )
-            else:
-                sampler = RandomSampler(train_dataset)
-
-            data_loader = MultipackDistributedDataloader(
-                train_dataset,
+            sampler = MultipackBatchSampler(
+                sampler=RandomSampler(train_dataset),
                batch_size=cfg.micro_batch_size,
-                seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
-                collate_fn=DataCollatorForSeq2Seq(
-                    tokenizer,
-                    return_tensors="pt",
-                    padding="longest",
+                drop_last=True,
+                batch_max_len=cfg.micro_batch_size
+                * (cfg.max_packed_sequence_len or cfg.sequence_len),
+                lengths=(
+                    train_dataset.data.column("position_ids")
+                    .to_pandas()
+                    .apply(lambda x: x[-1] + 1)
+                    .values
                ),
-                sampler=sampler,
-                packing_efficiency_estimate=cfg.sample_packing_eff_est,
-                sample_packing_seq_len_multiplier=cfg.micro_batch_size,
-                device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                num_epochs=cfg.num_epochs,
            )
-            data_loader_len = data_loader.len_w_stats()
-            actual_eff = data_loader.efficiency()
-            LOG.info(f"data_loader_len: {data_loader_len}")
+
+            data_loader = DataLoader(
+                train_dataset.remove_columns(["length"]),
+                batch_sampler=sampler,
+            )
+            data_loader_len = len(data_loader)
+            actual_eff = sampler.efficiency()
+            LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
            # FIXME: is there a bug here somewhere? the total num steps depends
            # on the agreed on value for sample_packing_eff_est
-            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
+            total_num_steps = int(
+                math.floor(
+                    data_loader_len
+                    * cfg.num_epochs
+                    / int(os.environ.get("WORLD_SIZE", 1))
+                )
+            )

        def calc_sample_packing_eff_est(estimates: List[float]):
            LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -237,12 +232,20 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
            )
            cfg.sample_packing_eff_est = sample_packing_eff_est
-            LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
+            LOG.debug(
+                f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
+                main_process_only=True,
+            )
    else:
        total_num_steps = int(
-            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+            math.ceil(
+                len(train_dataset)
+                * cfg.num_epochs
+                / int(os.environ.get("WORLD_SIZE", 1))
+                / cfg.batch_size
+            )
        )
-    LOG.info(f"total_num_steps: {total_num_steps}")
+    LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
    return total_num_steps
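Both step-count branches now divide by `WORLD_SIZE`, so the result is a per-device estimate. A quick worked example with illustrative numbers:

```python
import math

# 20k examples, 2 epochs, 2 devices, effective batch size 8 (values are illustrative)
dataset_len, num_epochs, world_size, batch_size = 20_000, 2, 2, 8
total_num_steps = int(math.ceil(dataset_len * num_epochs / world_size / batch_size))
assert total_num_steps == 2500  # without the world_size divisor this was 5000
```
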
@@ -265,6 +268,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
        setup_fsdp_envs(cfg)
    elif cfg.deepspeed:
        os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+    if cfg.fp8:
+        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"

    trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
    trainer_builder.train_dataset = train_dataset
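
The new `cfg.fp8` branch only exports the environment variable Accelerate reads for mixed precision; whether fp8 actually engages depends on Accelerate's fp8 backend being available (e.g. TransformerEngine on supported GPUs), which this diff does not touch. A minimal sketch, assuming the config key is `fp8`:

```python
import os

cfg_fp8 = True  # stand-in for cfg.fp8, e.g. `fp8: true` in the YAML config (assumption)
if cfg_fp8:
    os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
```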
tests/e2e/__init__.py (new file, empty)
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path

@@ -16,6 +15,8 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

+from .utils import with_temp_dir
+
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"

@@ -25,9 +26,9 @@ class TestFusedLlama(unittest.TestCase):
    Test case for Llama models using Fused layers
    """

-    def test_fft_packing(self):
+    @with_temp_dir
+    def test_fft_packing(self, temp_dir):
        # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
        cfg = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
@@ -51,7 +52,7 @@ class TestFusedLlama(unittest.TestCase):
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
-                "output_dir": output_dir,
+                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
@@ -69,4 +70,4 @@ class TestFusedLlama(unittest.TestCase):
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(output_dir) / "pytorch_model.bin").exists()
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path

@@ -14,6 +13,8 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

+from .utils import with_temp_dir
+
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"

@@ -23,9 +24,9 @@ class TestLoraLlama(unittest.TestCase):
    Test case for Llama models using LoRA
    """

-    def test_lora(self):
+    @with_temp_dir
+    def test_lora(self, temp_dir):
        # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
        cfg = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
@@ -52,7 +53,7 @@ class TestLoraLlama(unittest.TestCase):
                "num_epochs": 2,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
-                "output_dir": output_dir,
+                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
@@ -63,11 +64,11 @@ class TestLoraLlama(unittest.TestCase):
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(output_dir) / "adapter_model.bin").exists()
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

-    def test_lora_packing(self):
+    @with_temp_dir
+    def test_lora_packing(self, temp_dir):
        # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
        cfg = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
@@ -96,7 +97,7 @@ class TestLoraLlama(unittest.TestCase):
                "num_epochs": 2,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
-                "output_dir": output_dir,
+                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
@@ -107,11 +108,11 @@ class TestLoraLlama(unittest.TestCase):
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(output_dir) / "adapter_model.bin").exists()
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

-    def test_lora_gptq(self):
+    @with_temp_dir
+    def test_lora_gptq(self, temp_dir):
        # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
        cfg = DictDefault(
            {
                "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
@@ -144,7 +145,7 @@ class TestLoraLlama(unittest.TestCase):
                "save_steps": 0.5,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
-                "output_dir": output_dir,
+                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
@@ -155,4 +156,4 @@ class TestLoraLlama(unittest.TestCase):
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(output_dir) / "adapter_model.bin").exists()
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path

@@ -16,6 +15,8 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

+from .utils import with_temp_dir
+
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"

@@ -25,9 +26,9 @@ class TestMistral(unittest.TestCase):
    Test case for Llama models using LoRA
    """

-    def test_lora(self):
+    @with_temp_dir
+    def test_lora(self, temp_dir):
        # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
        cfg = DictDefault(
            {
                "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -54,7 +55,7 @@ class TestMistral(unittest.TestCase):
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
-                "output_dir": output_dir,
+                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
@@ -68,11 +69,11 @@ class TestMistral(unittest.TestCase):
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(output_dir) / "adapter_model.bin").exists()
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

-    def test_ft(self):
+    @with_temp_dir
+    def test_ft(self, temp_dir):
        # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
        cfg = DictDefault(
            {
                "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -93,7 +94,7 @@ class TestMistral(unittest.TestCase):
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
-                "output_dir": output_dir,
+                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
@@ -111,4 +112,4 @@ class TestMistral(unittest.TestCase):
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(output_dir) / "pytorch_model.bin").exists()
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path

@@ -16,6 +15,8 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

+from .utils import with_temp_dir
+
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"

@@ -25,9 +26,9 @@ class TestMistral(unittest.TestCase):
    Test case for Llama models using LoRA
    """

-    def test_lora_packing(self):
+    @with_temp_dir
+    def test_lora_packing(self, temp_dir):
        # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
        cfg = DictDefault(
            {
                "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -55,7 +56,7 @@ class TestMistral(unittest.TestCase):
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
-                "output_dir": output_dir,
+                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
@@ -69,11 +70,11 @@ class TestMistral(unittest.TestCase):
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(output_dir) / "adapter_model.bin").exists()
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

-    def test_ft_packing(self):
+    @with_temp_dir
+    def test_ft_packing(self, temp_dir):
        # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
        cfg = DictDefault(
            {
                "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -95,7 +96,7 @@ class TestMistral(unittest.TestCase):
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
-                "output_dir": output_dir,
+                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
@@ -113,4 +114,4 @@ class TestMistral(unittest.TestCase):
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(output_dir) / "pytorch_model.bin").exists()
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
@@ -4,8 +4,8 @@ E2E tests for lora llama

 import logging
 import os
-import tempfile
 import unittest
+from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,6 +13,8 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

+from .utils import with_temp_dir
+
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"

@@ -22,7 +24,8 @@ class TestPhi(unittest.TestCase):
    Test case for Llama models using LoRA
    """

-    def test_ft(self):
+    @with_temp_dir
+    def test_ft(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
@@ -52,7 +55,7 @@ class TestPhi(unittest.TestCase):
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
@@ -64,8 +67,10 @@ class TestPhi(unittest.TestCase):
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()

-    def test_ft_packed(self):
+    @with_temp_dir
+    def test_ft_packed(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
@@ -95,7 +100,7 @@ class TestPhi(unittest.TestCase):
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
@@ -107,3 +112,4 @@ class TestPhi(unittest.TestCase):
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
tests/e2e/utils.py (new file, +22)
@@ -0,0 +1,22 @@
+"""
+helper utils for tests
+"""
+
+import shutil
+import tempfile
+from functools import wraps
+
+
+def with_temp_dir(test_func):
+    @wraps(test_func)
+    def wrapper(*args, **kwargs):
+        # Create a temporary directory
+        temp_dir = tempfile.mkdtemp()
+        try:
+            # Pass the temporary directory to the test function
+            test_func(*args, temp_dir=temp_dir, **kwargs)
+        finally:
+            # Clean up the directory after the test
+            shutil.rmtree(temp_dir)
+
+    return wrapper
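
All of the e2e test changes above consume this decorator the same way; the pattern, in brief (class and test names below are illustrative):

```python
import unittest

from .utils import with_temp_dir


class TestExample(unittest.TestCase):
    @with_temp_dir
    def test_something(self, temp_dir):
        # temp_dir exists for the duration of the test and is removed afterwards,
        # replacing the leaked tempfile.mkdtemp() directories the tests used before
        ...
```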