Compare commits
54 Commits
docker-cle
...
unsloth_mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f9b172c47 | ||
|
|
8671ed5a0c | ||
|
|
538c004080 | ||
|
|
add3b139ed | ||
|
|
a581e9f8f6 | ||
|
|
992e742cdc | ||
|
|
a1da39cd48 | ||
|
|
58ec8b1113 | ||
|
|
476a205cea | ||
|
|
3e3229e2d9 | ||
|
|
1d21aa6b0a | ||
|
|
71b7ea3c05 | ||
|
|
a48dbf6561 | ||
|
|
6a4562ac08 | ||
|
|
1115c501b8 | ||
|
|
7ee3c4cacb | ||
|
|
fb12895a17 | ||
|
|
9fc29e082b | ||
|
|
575a082aae | ||
|
|
ddf815022a | ||
|
|
9bf854e59c | ||
|
|
797f3dd1de | ||
|
|
0de1457189 | ||
|
|
3cc67d2cdd | ||
|
|
1bc11868eb | ||
|
|
b3a61e8ce2 | ||
|
|
8a8d1c4023 | ||
|
|
332984db18 | ||
|
|
48630f5b34 | ||
|
|
b33c1d55a2 | ||
|
|
0c2a630326 | ||
|
|
db8a8afcba | ||
|
|
14706504e3 | ||
|
|
501b4d1379 | ||
|
|
306fe19c54 | ||
|
|
614cff4107 | ||
|
|
1a6309c8a6 | ||
|
|
105d0b350b | ||
|
|
f544ab2bed | ||
|
|
641e6f7e51 | ||
|
|
6dc68a653f | ||
|
|
7de6a5639c | ||
|
|
c74f045ba7 | ||
|
|
0402d19759 | ||
|
|
b2430ce670 | ||
|
|
4c834bf25d | ||
|
|
8056ecd30e | ||
|
|
738a057674 | ||
|
|
cdc71f73c8 | ||
|
|
6459ac7357 | ||
|
|
964d858da0 | ||
|
|
10388a8daf | ||
|
|
9f7e8a971d | ||
|
|
637ed095a0 |
1
.github/workflows/tests.yml
vendored
1
.github/workflows/tests.yml
vendored
@@ -71,6 +71,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
||||
pip3 uninstall -y transformers accelerate
|
||||
pip3 install -U -e .[flash-attn]
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
91
README.md
91
README.md
@@ -25,8 +25,10 @@ Features:
|
||||
- [Installation](#installation)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [Runpod](#runpod)
|
||||
- [LambdaLabs](#lambdalabs)
|
||||
- [Windows](#windows)
|
||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||
- [Dataset](#dataset)
|
||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||
@@ -74,6 +76,8 @@ Features:
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
|
||||
|
||||
## Quickstart ⚡
|
||||
@@ -82,20 +86,29 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
||||
|
||||
**Requirements**: Python >=3.9 and Pytorch >=2.0.
|
||||
|
||||
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
|
||||
|
||||
### For developers
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
```
|
||||
|
||||
### Usage
|
||||
```bash
|
||||
# finetune lora
|
||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||
|
||||
# inference
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out"
|
||||
|
||||
# gradio
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out" --gradio
|
||||
```
|
||||
|
||||
## Installation
|
||||
@@ -106,7 +119,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
- `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
Or run on the current files for development:
|
||||
|
||||
@@ -121,13 +133,15 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
A more powerful Docker command to run would be this:
|
||||
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
|
||||
It additionally:
|
||||
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
|
||||
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
|
||||
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
|
||||
* The `--privileged` flag gives all capabilities to the container.
|
||||
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
|
||||
|
||||
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
|
||||
|
||||
@@ -149,6 +163,10 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
```
|
||||
Get the token at huggingface.co/settings/tokens
|
||||
|
||||
#### Runpod
|
||||
|
||||
Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
#### LambdaLabs
|
||||
<details>
|
||||
|
||||
@@ -196,6 +214,28 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
#### Windows
|
||||
Please use WSL or Docker!
|
||||
|
||||
|
||||
#### Launching on public clouds via SkyPilot
|
||||
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
|
||||
```bash
|
||||
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
|
||||
sky check
|
||||
```
|
||||
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
|
||||
```
|
||||
git clone https://github.com/skypilot-org/skypilot.git
|
||||
cd skypilot/llm/axolotl
|
||||
```
|
||||
Use one command to launch:
|
||||
```bash
|
||||
# On-demand
|
||||
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
||||
|
||||
# Managed spot (auto-recovery on preemption)
|
||||
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
||||
```
|
||||
|
||||
|
||||
### Dataset
|
||||
|
||||
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
|
||||
@@ -392,6 +432,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- path: knowrohit07/know_sql
|
||||
type: context_qa.load_v2
|
||||
train_on_split: validation
|
||||
|
||||
# loading from s3 or gcs
|
||||
# s3 creds will be loaded from the system default and gcs only supports public access
|
||||
dataset:
|
||||
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
||||
...
|
||||
```
|
||||
|
||||
- loading
|
||||
@@ -454,6 +500,15 @@ is_falcon_derived_model:
|
||||
is_llama_derived_model:
|
||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
is_mistral_derived_model:
|
||||
is_qwen_derived_model:
|
||||
|
||||
# optional overrides to the base model configuration
|
||||
model_config:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
@@ -478,7 +533,7 @@ float16: true
|
||||
|
||||
# A list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
@@ -486,9 +541,12 @@ datasets:
|
||||
data_files: # Optional[str] path to source data files
|
||||
shards: # Optional[int] number of shards to split data into
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
train_on_split: train # Optional[str] name of dataset split to load from
|
||||
|
||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
field_human: # Optional[str]. Human key to use for conversation.
|
||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||
|
||||
# Custom user prompt
|
||||
- path: repo
|
||||
@@ -554,6 +612,12 @@ eval_sample_packing:
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
|
||||
# Passed through to transformers when loading the model when launched without accelerate
|
||||
# Use `sequential` when training w/ model parallelism to limit memory
|
||||
device_map:
|
||||
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
|
||||
max_memory:
|
||||
|
||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# If you already have a lora model trained that you want to load, put that here.
|
||||
@@ -601,7 +665,8 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
|
||||
wandb_project: # Your wandb project name
|
||||
wandb_entity: # A wandb Team name if using a Team
|
||||
wandb_watch:
|
||||
wandb_run_id: # Set the name of your wandb run
|
||||
wandb_name: # Set the name of your wandb run
|
||||
wandb_run_id: # Set the ID of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
@@ -619,7 +684,8 @@ gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
eval_batch_size:
|
||||
num_epochs: 4
|
||||
warmup_steps: 100
|
||||
warmup_steps: 100 # cannot use with warmup_ratio
|
||||
warmup_ratio: 0.05 # cannot use with warmup_steps
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
@@ -635,6 +701,9 @@ max_steps:
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
|
||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||
|
||||
# Save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
@@ -721,10 +790,6 @@ landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# LLaMA only
|
||||
xpos_rope:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# Resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
@@ -897,7 +962,7 @@ wandb_mode:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
@@ -918,6 +983,10 @@ Pass the appropriate flag to the train command:
|
||||
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
|
||||
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
||||
```
|
||||
-- With gradio hosting
|
||||
```bash
|
||||
python -m axolotl.cli.inference examples/your_config.yml --gradio
|
||||
```
|
||||
|
||||
Please use `--sample_packing False` if you have it on and receive the error similar to below:
|
||||
|
||||
|
||||
@@ -24,16 +24,6 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
@@ -28,16 +28,6 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
@@ -32,16 +32,6 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
@@ -8,7 +8,6 @@ ENV BNB_CUDA_VERSION=$CUDA
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER=1
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y vim curl
|
||||
@@ -22,9 +21,9 @@ WORKDIR /workspace/axolotl
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
else \
|
||||
pip install -e .[flash-attn]; \
|
||||
pip install -e .[deepspeed,flash-attn]; \
|
||||
fi
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
|
||||
@@ -10,11 +10,13 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ARG PYTHON_VERSION="3.9"
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
ARG CUDA="118"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev unzip && rm -rf /var/lib/apt/lists/* \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
@@ -27,47 +29,9 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
|
||||
FROM base-builder AS deepspeed-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
cd DeepSpeed && \
|
||||
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
|
||||
|
||||
FROM base-builder AS bnb-builder
|
||||
|
||||
WORKDIR /workspace
|
||||
ARG CUDA="118"
|
||||
ENV CUDA=$CUDA
|
||||
ARG MAX_JOBS="-1"
|
||||
ENV MAX_JOBS=$MAX_JOBS
|
||||
|
||||
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
|
||||
cd bitsandbytes && \
|
||||
CUDA_VERSION=$CUDA make cuda11x && \
|
||||
python setup.py bdist_wheel
|
||||
|
||||
FROM base-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN mkdir -p /workspace/builds
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
||||
|
||||
RUN mkdir -p /workspace/wheels/bitsandbytes
|
||||
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
|
||||
|
||||
RUN pip3 install wheels/deepspeed-*.whl
|
||||
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
|
||||
RUN git lfs install --skip-repo
|
||||
RUN pip3 install awscli && \
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
@@ -12,3 +12,7 @@ This usually happens when you run out of system RAM.
|
||||
> Exitcode -7 while using deepspeed
|
||||
|
||||
Try upgrading deepspeed w: `pip install -U deepspeed`
|
||||
|
||||
> AttributeError: 'DummyOptim' object has no attribute 'step'
|
||||
|
||||
You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
||||
|
||||
@@ -14,7 +14,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_prepared_run
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
output_dir: btlm-out
|
||||
|
||||
@@ -7,7 +7,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
batch_size: 4
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -12,7 +12,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./falcon-7b
|
||||
batch_size: 2
|
||||
|
||||
@@ -18,7 +18,7 @@ datasets:
|
||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||
type: "alpaca:chat"
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./falcon-7b
|
||||
batch_size: 2
|
||||
|
||||
@@ -7,7 +7,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 2
|
||||
|
||||
@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./jeopardy-bot-7b
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -15,7 +15,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 4096
|
||||
@@ -32,7 +32,7 @@ lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./model-out
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./relora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -35,7 +35,7 @@ relora_cpu_offload: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: PY007/TinyLlama-1.1B-step-50K-105b
|
||||
base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
|
||||
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
@@ -12,7 +12,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 8192
|
||||
@@ -21,7 +21,7 @@ pad_to_sequence_len: true
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -38,7 +38,7 @@ lora_target_modules:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -62,6 +62,9 @@ logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
|
||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: mpt-alpaca-7b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./mpt-alpaca-7b
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./openllama-out
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-out
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -9,7 +9,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
model_type: MixFormerSequentialForCausalLM
|
||||
model_type: PhiForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_llama_derived_model: false
|
||||
trust_remote_code: true
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./pythia-12b
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-alpaca-pythia
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
68
examples/qwen/lora.yml
Normal file
68
examples/qwen/lora.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: Qwen/Qwen-7B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
is_qwen_derived_model: true
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 2048 # supports up to 8192
|
||||
sample_packing: false
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
68
examples/qwen/qlora.yml
Normal file
68
examples/qwen/qlora.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: Qwen/Qwen-7B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
is_qwen_derived_model: true
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 2048 # supports up to 8192
|
||||
sample_packing: false
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: redpajama-alpaca-3b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./redpajama-alpaca-3b
|
||||
batch_size: 4
|
||||
|
||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project: lora-replit
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-replit
|
||||
batch_size: 8
|
||||
|
||||
@@ -16,7 +16,7 @@ datasets:
|
||||
- openassistant_best_replies_train.jsonl
|
||||
type: "completion"
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# Page
|
||||
@@ -1,4 +0,0 @@
|
||||
# Table of contents
|
||||
|
||||
* [Page](README.md)
|
||||
* [Small dev details](small-dev-details.md)
|
||||
@@ -1,3 +0,0 @@
|
||||
# Small dev details
|
||||
|
||||
/
|
||||
@@ -1,23 +1,22 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
torch==2.0.1
|
||||
auto-gptq
|
||||
auto-gptq==0.5.1
|
||||
packaging
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
|
||||
peft==0.6.0
|
||||
transformers==4.35.2
|
||||
tokenizers==0.15.0
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
|
||||
accelerate==0.24.1
|
||||
deepspeed
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
datasets
|
||||
flash-attn>=2.3.0
|
||||
datasets>=2.15.0
|
||||
flash-attn==2.3.3
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers>=0.0.22
|
||||
optimum
|
||||
xformers==0.0.22
|
||||
optimum==1.13.2
|
||||
hf_transfer
|
||||
colorama
|
||||
numba
|
||||
@@ -31,3 +30,10 @@ scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat==0.2.29
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
|
||||
# remote filesystems
|
||||
s3fs
|
||||
gcsfs
|
||||
# adlfs
|
||||
|
||||
@@ -6,8 +6,10 @@ import os
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
@@ -16,7 +18,7 @@ from accelerate.commands.config import config_args
|
||||
from art import text2art
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.logging_config import configure_logging
|
||||
@@ -27,6 +29,7 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
@@ -44,7 +47,7 @@ def print_axolotl_text_art(suffix=None):
|
||||
ascii_text = " axolotl"
|
||||
if suffix:
|
||||
ascii_text += f" x {suffix}"
|
||||
ascii_art = text2art(" axolotl", font=font)
|
||||
ascii_art = text2art(ascii_text, font=font)
|
||||
|
||||
if is_main_process():
|
||||
print(ascii_art)
|
||||
@@ -69,7 +72,7 @@ def do_merge_lora(
|
||||
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
model = model.merge_and_unload()
|
||||
model.to(dtype=torch.float16)
|
||||
model.to(dtype=cfg.torch_dtype)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||
@@ -153,6 +156,91 @@ def do_inference(
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
def do_inference_gradio(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
model = model.to(cfg.device)
|
||||
|
||||
def generate(instruction):
|
||||
if not instruction:
|
||||
return
|
||||
if prompter_module:
|
||||
# pylint: disable=stop-iteration-return
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer)
|
||||
generation_kwargs = {
|
||||
"inputs": batch["input_ids"].to(cfg.device),
|
||||
"generation_config": generation_config,
|
||||
"streamer": streamer,
|
||||
}
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
all_text = ""
|
||||
|
||||
for new_text in streamer:
|
||||
all_text += new_text
|
||||
yield all_text
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs="textbox",
|
||||
outputs="text",
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
demo.queue().launch(show_api=False, share=True)
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
yaml_files = list(path.glob("*.yml"))
|
||||
|
||||
@@ -209,6 +297,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
@@ -6,11 +6,16 @@ from pathlib import Path
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
|
||||
from axolotl.cli import (
|
||||
do_inference,
|
||||
do_inference_gradio,
|
||||
load_cfg,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
)
|
||||
parsed_cli_args.inference = True
|
||||
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if gradio:
|
||||
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -6,33 +6,33 @@ import abc
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import SequentialDistributedSampler
|
||||
from transformers.trainer_utils import seed_worker
|
||||
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.samplers import MultipackBatchSampler
|
||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||
|
||||
try:
|
||||
@@ -102,6 +102,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -145,70 +149,102 @@ class AxolotlTrainer(Trainer):
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size > 1 and self.args.sample_packing:
|
||||
return DistributedSampler(
|
||||
self.train_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
seed=self.args.seed,
|
||||
if self.args.sample_packing:
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
self.args.train_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
||||
lengths=(
|
||||
self.train_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if (
|
||||
self.args.world_size > 1
|
||||
and self.args.sample_packing
|
||||
and self.args.eval_sample_packing is not False
|
||||
):
|
||||
return SequentialDistributedSampler(
|
||||
eval_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
batch_size=self.args.per_device_eval_batch_size,
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
self.args.per_device_eval_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
||||
lengths=(
|
||||
eval_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.args.sample_packing:
|
||||
train_sampler = self._get_train_sampler()
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
self.train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=train_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
num_epochs=self.num_epochs,
|
||||
)
|
||||
train_dataset = self.train_dataset
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
sampler = self._get_train_sampler()
|
||||
if isinstance(sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(
|
||||
self, eval_dataset: Optional[Dataset] = None
|
||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=eval_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
num_epochs=self.num_epochs,
|
||||
)
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if isinstance(eval_sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = eval_sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = eval_sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(eval_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
@@ -222,13 +258,15 @@ class AxolotlTrainer(Trainer):
|
||||
def get_bench_dataloader(
|
||||
self,
|
||||
bench_dataset: Dataset,
|
||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
) -> DataLoader:
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": self.bench_data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
@@ -393,6 +431,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -424,11 +465,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
warmup_steps = (
|
||||
self.cfg.warmup_steps
|
||||
if self.cfg.warmup_steps is not None
|
||||
else min(int(0.03 * total_num_steps), 100)
|
||||
)
|
||||
warmup_steps = None
|
||||
if self.cfg.warmup_steps is not None:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio is not None:
|
||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||
else:
|
||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||
|
||||
logging_steps = (
|
||||
self.cfg.logging_steps
|
||||
if self.cfg.logging_steps is not None
|
||||
@@ -493,16 +537,29 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"sample_packing_efficiency"
|
||||
] = self.cfg.sample_packing_eff_est
|
||||
|
||||
if self.cfg.eval_steps:
|
||||
if self.cfg.dataloader_pin_memory is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_pin_memory"
|
||||
] = self.cfg.dataloader_pin_memory
|
||||
if self.cfg.dataloader_num_workers is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_num_workers"
|
||||
] = self.cfg.dataloader_num_workers
|
||||
if self.cfg.dataloader_prefetch_factor is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_prefetch_factor"
|
||||
] = self.cfg.dataloader_prefetch_factor
|
||||
|
||||
if self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.evaluation_strategy:
|
||||
training_arguments_kwargs[
|
||||
"evaluation_strategy"
|
||||
] = self.cfg.evaluation_strategy
|
||||
elif self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
@@ -590,7 +647,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
||||
training_arguments_kwargs["run_name"] = (
|
||||
self.cfg.wandb_run_id if self.cfg.use_wandb else None
|
||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||
)
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
@@ -608,7 +665,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = (
|
||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||
self.cfg.sample_packing
|
||||
if self.cfg.eval_sample_packing is not False
|
||||
else False
|
||||
)
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_seq_len_multiplier"
|
||||
@@ -672,7 +731,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
data_collator=BatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
@@ -690,4 +749,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
if self.cfg.deepspeed and self.cfg.sample_packing:
|
||||
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
|
||||
"train_micro_batch_size_per_gpu"
|
||||
] = self.cfg.micro_batch_size
|
||||
|
||||
return trainer
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: IterableDataset,
|
||||
process_count: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
self.process_count = process_count
|
||||
super().__init__(self.process(dataset).data, **kwargs)
|
||||
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, os.cpu_count())
|
||||
num_proc = (
|
||||
min(64, self.process_count)
|
||||
if self.process_count
|
||||
else min(64, os.cpu_count())
|
||||
)
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
|
||||
@@ -3,4 +3,6 @@ MixFormers model architecture used for phi models
|
||||
"""
|
||||
|
||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
|
||||
from .configuration_phi import PhiConfig # noqa
|
||||
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
|
||||
from .modeling_phi import PhiForCausalLM # noqa
|
||||
|
||||
65
src/axolotl/models/phi/configuration_phi.py
Normal file
65
src/axolotl/models/phi/configuration_phi.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# pylint: skip-file
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class PhiConfig(PretrainedConfig):
|
||||
"""Phi configuration."""
|
||||
|
||||
model_type = "phi"
|
||||
attribute_map = {
|
||||
"max_position_embeddings": "n_positions",
|
||||
"hidden_size": "n_embd",
|
||||
"num_attention_heads": "n_head",
|
||||
"num_hidden_layers": "n_layer",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 50304,
|
||||
n_positions: int = 2048,
|
||||
n_embd: int = 1024,
|
||||
n_layer: int = 20,
|
||||
n_inner: Optional[int] = None,
|
||||
n_head: int = 16,
|
||||
n_head_kv: Optional[int] = None,
|
||||
rotary_dim: Optional[int] = 32,
|
||||
activation_function: Optional[str] = "gelu_new",
|
||||
flash_attn: bool = False,
|
||||
flash_rotary: bool = False,
|
||||
fused_dense: bool = False,
|
||||
attn_pdrop: float = 0.0,
|
||||
embd_pdrop: float = 0.0,
|
||||
resid_pdrop: float = 0.0,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
initializer_range: float = 0.02,
|
||||
tie_word_embeddings: bool = False,
|
||||
pad_vocab_size_multiple: int = 64,
|
||||
**kwargs
|
||||
) -> None:
|
||||
self.vocab_size = int(
|
||||
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
)
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_inner = n_inner
|
||||
self.n_head = n_head
|
||||
self.n_head_kv = n_head_kv
|
||||
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
||||
self.activation_function = activation_function
|
||||
self.flash_attn = flash_attn
|
||||
self.flash_rotary = flash_rotary
|
||||
self.fused_dense = fused_dense
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
1063
src/axolotl/models/phi/modeling_phi.py
Normal file
1063
src/axolotl/models/phi/modeling_phi.py
Normal file
File diff suppressed because it is too large
Load Diff
168
src/axolotl/monkeypatch/cross_entropy.py
Normal file
168
src/axolotl/monkeypatch/cross_entropy.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# Adapted from Unsloth
|
||||
# https://github.com/unslothai/unsloth/blob/4b97a810b509c93f44be4c037c7aa18fb8922884/unsloth/kernels/cross_entropy_loss.py
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
MAX_FUSED_SIZE = 65536
|
||||
|
||||
def calculate_settings(n):
|
||||
BLOCK_SIZE = triton.next_power_of_2(n)
|
||||
# CUDA only supports 65536 - 2^16 threads per block
|
||||
if BLOCK_SIZE > MAX_FUSED_SIZE:
|
||||
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
|
||||
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
|
||||
num_warps = 4
|
||||
if BLOCK_SIZE >= 32768: num_warps = 32
|
||||
elif BLOCK_SIZE >= 8192: num_warps = 16
|
||||
elif BLOCK_SIZE >= 2048: num_warps = 8
|
||||
return BLOCK_SIZE, num_warps
|
||||
pass
|
||||
|
||||
@triton.jit
|
||||
def _cross_entropy_forward(logits_ptr, logits_row_stride,
|
||||
loss_ptr,
|
||||
lse_ptr,
|
||||
labels_ptr,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,):
|
||||
"""
|
||||
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
|
||||
Pi = exp(xi) / sum(exp(xi))
|
||||
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
|
||||
= -y [ x - log[sum(exp(x))] ]
|
||||
= y * (log[sum(exp(x))] - x)
|
||||
If y == 0: CE_i = 0
|
||||
If y == 1: CE_i = logsumexp - x
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
logits_ptr += row_idx * logits_row_stride
|
||||
loss_ptr += row_idx
|
||||
lse_ptr += row_idx
|
||||
labels_ptr += row_idx
|
||||
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
# TODO: Fixup int32 locations to int64
|
||||
label_idx = tl.load(labels_ptr).to(tl.int32)
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
||||
max_logits = tl.max(logits, 0)
|
||||
# Maximum stops overflow
|
||||
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
|
||||
tl.store(lse_ptr, lse)
|
||||
|
||||
if label_idx != -100:
|
||||
logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
|
||||
loss = lse - logits_label
|
||||
else:
|
||||
loss = 0.0
|
||||
tl.store(loss_ptr, loss)
|
||||
pass
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _cross_entropy_backward(logits_ptr, logits_row_stride,
|
||||
dloss_ptr, dloss_row_stride,
|
||||
lse_ptr,
|
||||
labels_ptr,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,):
|
||||
"""
|
||||
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
|
||||
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
|
||||
|
||||
From https://en.wikipedia.org/wiki/LogSumExp
|
||||
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
|
||||
|
||||
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
|
||||
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
|
||||
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
|
||||
|
||||
If y == 0: dC/dx = 0
|
||||
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
|
||||
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
logits_ptr += row_idx * logits_row_stride
|
||||
dloss_ptr += row_idx * dloss_row_stride
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
# TODO: Fixup int32 locations to int64
|
||||
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
|
||||
|
||||
if label_idx != -100:
|
||||
dloss = tl.load(dloss_ptr)
|
||||
else:
|
||||
dloss = 0.0
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
lse = tl.load(lse_ptr + row_idx)
|
||||
probs = tl.exp(logits - lse)
|
||||
|
||||
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
|
||||
tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)
|
||||
|
||||
|
||||
|
||||
class CrossEntropyLoss(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, logits, labels):
|
||||
n_rows, n_cols = logits.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
|
||||
_cross_entropy_forward[(n_rows,)](
|
||||
logits, logits.stride(0),
|
||||
losses,
|
||||
logsumexp,
|
||||
labels,
|
||||
n_cols,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
num_warps = num_warps,
|
||||
)
|
||||
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.save_for_backward(logits, logsumexp, labels)
|
||||
return losses
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dlosses):
|
||||
logits, logsumexp, labels = ctx.saved_tensors
|
||||
n_rows, n_cols = logits.shape
|
||||
|
||||
_cross_entropy_backward[(n_rows,)](
|
||||
logits, logits.stride(0),
|
||||
dlosses, dlosses.stride(0),
|
||||
logsumexp,
|
||||
labels,
|
||||
n_cols,
|
||||
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
||||
num_warps = ctx.num_warps,
|
||||
)
|
||||
return logits, None, None,
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def fast_cross_entropy_loss(logits, labels):
|
||||
"""
|
||||
Arguments:
|
||||
logits: (batch, seq_len, vocab_size)
|
||||
labels: (batch, seq_len,)
|
||||
Returns:
|
||||
losses: float
|
||||
"""
|
||||
batch, seq_len, d = logits.shape
|
||||
assert(labels.shape == (batch, seq_len))
|
||||
|
||||
loss = CrossEntropyLoss.apply(
|
||||
logits.view(batch*seq_len, d),
|
||||
labels.view(-1),
|
||||
)
|
||||
n_items = torch.count_nonzero(labels != -100)
|
||||
return loss.sum() / n_items
|
||||
pass
|
||||
@@ -321,6 +321,8 @@ def flashattn_forward(
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
@@ -330,7 +332,12 @@ def flashattn_forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
@@ -353,7 +360,7 @@ def flashattn_forward(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
0.0,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
@@ -366,6 +373,7 @@ def flashattn_forward(
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
)
|
||||
else:
|
||||
@@ -398,7 +406,7 @@ def flashattn_forward(
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
0.0,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
|
||||
@@ -25,6 +25,8 @@ def sdp_attention_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@@ -29,6 +29,8 @@ def xformers_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@@ -13,16 +13,20 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention as OriginalMistralAttention,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralForCausalLM as OriginalMistralForCausalLM,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||
|
||||
@@ -36,6 +40,9 @@ def replace_mistral_attn_with_flash_attn(
|
||||
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||
flashattn_forward
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralForCausalLM.forward = (
|
||||
mistral_causallm_forward
|
||||
)
|
||||
if packed:
|
||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||
MistralDecoderLayer
|
||||
@@ -201,6 +208,8 @@ def flashattn_forward(
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
@@ -213,7 +222,7 @@ def flashattn_forward(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
0.0,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
@@ -239,7 +248,7 @@ def flashattn_forward(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
0.0,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
@@ -253,6 +262,7 @@ def flashattn_forward(
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
@@ -286,7 +296,7 @@ def flashattn_forward(
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
0.0,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
@@ -638,3 +648,71 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def mistral_causallm_forward(
|
||||
self: OriginalMistralForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
*args, **kwargs
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
shift_logits = logits
|
||||
if not hasattr(self, "extra_ignored_labels"):
|
||||
self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device)
|
||||
|
||||
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
|
||||
# FAST CROSS ENTROPY
|
||||
loss = fast_cross_entropy_loss(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
@@ -22,7 +22,13 @@ class PromptStyle(Enum):
|
||||
CHATML = "chatml"
|
||||
|
||||
|
||||
class AlpacaPrompter:
|
||||
class Prompter:
|
||||
"""
|
||||
Base prompter class for all prompters
|
||||
"""
|
||||
|
||||
|
||||
class AlpacaPrompter(Prompter):
|
||||
"""
|
||||
Base class for alpaca prompters
|
||||
"""
|
||||
@@ -69,7 +75,7 @@ class AlpacaPrompter:
|
||||
else:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_no_input_prompt)
|
||||
if self.system_prompt
|
||||
if self.system_no_input_prompt
|
||||
else ""
|
||||
) + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
@@ -159,7 +165,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
|
||||
"""
|
||||
|
||||
|
||||
class ReflectAlpacaPrompter:
|
||||
class ReflectAlpacaPrompter(Prompter):
|
||||
"""
|
||||
Prompter for ReflectAlpaca
|
||||
"""
|
||||
@@ -254,7 +260,7 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
@@ -349,7 +355,7 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedPrompter:
|
||||
class UnsupportedPrompter(Prompter):
|
||||
"""
|
||||
A dummy class for custom prompters
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
@@ -10,6 +9,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
@@ -26,7 +26,7 @@ src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.train")
|
||||
LOG = get_logger("axolotl.train")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -44,7 +44,10 @@ def train(
|
||||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||
):
|
||||
# load the tokenizer first
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
LOG.debug(
|
||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||
main_process_only=True,
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
@@ -52,7 +55,10 @@ def train(
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
# Load the model and tokenizer
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
msg = "loading model"
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
@@ -124,6 +124,36 @@ class GPUStatsCallback(
|
||||
return control
|
||||
|
||||
|
||||
class LossWatchDogCallback(TrainerCallback):
|
||||
"""Callback to track loss and stop training if loss is too high"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.logged = False
|
||||
self.violations = 0
|
||||
self.threshold = cfg.loss_watchdog_threshold
|
||||
self.patience = cfg.loss_watchdog_patience or 3
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
_args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**_kwargs,
|
||||
):
|
||||
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
|
||||
if state.log_history[-1]["loss"] > self.threshold:
|
||||
self.violations += 1
|
||||
if self.violations >= self.patience:
|
||||
LOG.warning(
|
||||
"Loss is too high, stopping training (loss_watchdog_threshold)"
|
||||
)
|
||||
control.should_training_stop = True
|
||||
else:
|
||||
self.violations = 0
|
||||
return control
|
||||
|
||||
|
||||
def bench_eval_callback_factory(trainer, tokenizer):
|
||||
accuracy = evaluate.load("accuracy")
|
||||
abcd_idx = [
|
||||
|
||||
@@ -119,3 +119,30 @@ class DataCollatorForSeq2Seq:
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
Collator for multipack specific to the using the BatchSampler
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for item in features
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
@@ -27,7 +27,7 @@ def choose_device(cfg):
|
||||
|
||||
cfg.device = get_device()
|
||||
if cfg.world_size == 1:
|
||||
cfg.device_map = "auto"
|
||||
cfg.device_map = cfg.device_map or "auto"
|
||||
else:
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": torch.cuda.current_device()}
|
||||
@@ -122,6 +122,19 @@ def normalize_config(cfg):
|
||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||
)
|
||||
|
||||
cfg.is_qwen_derived_model = (
|
||||
(
|
||||
hasattr(model_config, "model_type")
|
||||
and model_config.model_type
|
||||
in [
|
||||
"qwen",
|
||||
]
|
||||
)
|
||||
or cfg.is_qwen_derived_model
|
||||
or "qwen" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "qwen" in cfg.model_type.lower())
|
||||
)
|
||||
|
||||
if isinstance(cfg.learning_rate, str):
|
||||
cfg.learning_rate = float(cfg.learning_rate)
|
||||
|
||||
@@ -165,7 +178,11 @@ def validate_config(cfg):
|
||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||
)
|
||||
if cfg.eval_batch_size != cfg.micro_batch_size:
|
||||
if (
|
||||
cfg.eval_batch_size
|
||||
and cfg.micro_batch_size
|
||||
and cfg.eval_batch_size != cfg.micro_batch_size
|
||||
):
|
||||
LOG.warning(
|
||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||
)
|
||||
@@ -369,6 +386,24 @@ def validate_config(cfg):
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
)
|
||||
|
||||
if cfg.rope_scaling:
|
||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||
|
||||
if cfg.warmup_steps and cfg.warmup_ratio:
|
||||
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
||||
|
||||
if cfg.is_qwen_derived_model and cfg.gradient_checkpointing:
|
||||
LOG.warning(
|
||||
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
||||
)
|
||||
|
||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||
cfg.wandb_name = cfg.wandb_run_id
|
||||
|
||||
LOG.warning(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
)
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools
|
||||
import hashlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from datasets import (
|
||||
@@ -34,6 +34,7 @@ from axolotl.prompters import (
|
||||
JeopardyPrompter,
|
||||
MultipleChoiceConcisePrompter,
|
||||
MultipleChoiceExplainPrompter,
|
||||
Prompter,
|
||||
ReflectAlpacaPrompter,
|
||||
SummarizeTLDRPrompter,
|
||||
UnsupportedPrompter,
|
||||
@@ -78,19 +79,27 @@ def prepare_dataset(cfg, tokenizer):
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset, tokenizer
|
||||
)
|
||||
|
||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||
if total_eval_steps == 0:
|
||||
raise ValueError(
|
||||
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
|
||||
)
|
||||
|
||||
if cfg.max_steps:
|
||||
total_num_steps = min(
|
||||
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||
)
|
||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||
else:
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||
|
||||
|
||||
def load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, default_dataset_prepared_path
|
||||
) -> DatasetDict:
|
||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||
tokenizer_name = tokenizer.__class__.__name__
|
||||
ds_hash = str(
|
||||
md5(
|
||||
@@ -98,7 +107,12 @@ def load_tokenized_prepared_datasets(
|
||||
str(cfg.sequence_len)
|
||||
+ "@"
|
||||
+ "|".join(
|
||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||
sorted(
|
||||
[
|
||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
|
||||
for d in cfg.datasets
|
||||
]
|
||||
)
|
||||
)
|
||||
+ "|"
|
||||
+ tokenizer_name
|
||||
@@ -164,6 +178,66 @@ def load_tokenized_prepared_datasets(
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
pass
|
||||
|
||||
ds_from_cloud = False
|
||||
storage_options = {}
|
||||
remote_file_system = None
|
||||
if config_dataset.path.startswith("s3://"):
|
||||
try:
|
||||
import aiobotocore.session # type: ignore
|
||||
import s3fs # type: ignore
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"s3:// paths require aiobotocore and s3fs to be installed"
|
||||
) from exc
|
||||
|
||||
# Takes credentials from ~/.aws/credentials for default profile
|
||||
s3_session = aiobotocore.session.AioSession(profile="default")
|
||||
storage_options = {"session": s3_session}
|
||||
remote_file_system = s3fs.S3FileSystem(**storage_options)
|
||||
elif config_dataset.path.startswith(
|
||||
"gs://"
|
||||
) or config_dataset.path.startswith("gcs://"):
|
||||
try:
|
||||
import gcsfs # type: ignore
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"gs:// or gcs:// paths require gcsfs to be installed"
|
||||
) from exc
|
||||
|
||||
# gcsfs will use default credentials from the environment else anon
|
||||
# https://gcsfs.readthedocs.io/en/latest/#credentials
|
||||
storage_options = {"token": None}
|
||||
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
|
||||
# TODO: Figure out how to get auth creds passed
|
||||
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
|
||||
# try:
|
||||
# import adlfs
|
||||
# except ImportError as exc:
|
||||
# raise ImportError(
|
||||
# "adl:// or abfs:// paths require adlfs to be installed"
|
||||
# ) from exc
|
||||
|
||||
# # Gen 1
|
||||
# storage_options = {
|
||||
# "tenant_id": TENANT_ID,
|
||||
# "client_id": CLIENT_ID,
|
||||
# "client_secret": CLIENT_SECRET,
|
||||
# }
|
||||
# # Gen 2
|
||||
# storage_options = {
|
||||
# "account_name": ACCOUNT_NAME,
|
||||
# "account_key": ACCOUNT_KEY,
|
||||
# }
|
||||
|
||||
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
|
||||
try:
|
||||
if remote_file_system and remote_file_system.exists(
|
||||
config_dataset.path
|
||||
):
|
||||
ds_from_cloud = True
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
pass
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
local_path = Path(config_dataset.path)
|
||||
if local_path.exists():
|
||||
@@ -177,17 +251,8 @@ def load_tokenized_prepared_datasets(
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = "json"
|
||||
if config_dataset.ds_type:
|
||||
ds_type = config_dataset.ds_type
|
||||
elif ".parquet" in config_dataset.path:
|
||||
ds_type = "parquet"
|
||||
elif ".arrow" in config_dataset.path:
|
||||
ds_type = "arrow"
|
||||
elif ".csv" in config_dataset.path:
|
||||
ds_type = "csv"
|
||||
elif ".txt" in config_dataset.path:
|
||||
ds_type = "text"
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
@@ -207,6 +272,22 @@ def load_tokenized_prepared_datasets(
|
||||
data_files=config_dataset.data_files,
|
||||
token=use_auth_token,
|
||||
)
|
||||
elif ds_from_cloud and remote_file_system:
|
||||
if remote_file_system.isdir(config_dataset.path):
|
||||
ds = load_from_disk(
|
||||
config_dataset.path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
elif remote_file_system.isfile(config_dataset.path):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
@@ -298,11 +379,29 @@ def load_tokenized_prepared_datasets(
|
||||
return dataset, prompters
|
||||
|
||||
|
||||
def get_ds_type(config_dataset: DictDefault):
|
||||
"""
|
||||
Get the dataset type from the path if it's not specified
|
||||
"""
|
||||
ds_type = "json"
|
||||
if config_dataset.ds_type:
|
||||
ds_type = config_dataset.ds_type
|
||||
elif ".parquet" in config_dataset.path:
|
||||
ds_type = "parquet"
|
||||
elif ".arrow" in config_dataset.path:
|
||||
ds_type = "arrow"
|
||||
elif ".csv" in config_dataset.path:
|
||||
ds_type = "csv"
|
||||
elif ".txt" in config_dataset.path:
|
||||
ds_type = "text"
|
||||
return ds_type
|
||||
|
||||
|
||||
def load_prepare_datasets(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
cfg,
|
||||
default_dataset_prepared_path,
|
||||
) -> Tuple[Dataset, Dataset, List[Any]]:
|
||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
||||
max_packed_sequence_len = (
|
||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||
)
|
||||
@@ -311,7 +410,7 @@ def load_prepare_datasets(
|
||||
) # make sure we don't accidentally set it larger than sequence_len
|
||||
|
||||
tokenizer_name = tokenizer.__class__.__name__
|
||||
prompters = []
|
||||
prompters: List[Prompter] = []
|
||||
if cfg.max_packed_sequence_len is not None:
|
||||
# see if we can go ahead and load the stacked dataset
|
||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||
@@ -445,14 +544,13 @@ def load_prepare_datasets(
|
||||
train_fingerprint = md5(to_hash_train)
|
||||
test_fingerprint = md5(to_hash_test)
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed or 42,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed or 42,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
@@ -482,10 +580,14 @@ def get_dataset_wrapper(
|
||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||
)
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
elif d_base_type == "alpaca":
|
||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||
@@ -494,7 +596,9 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "explainchoice":
|
||||
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
||||
@@ -504,7 +608,9 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "concisechoice":
|
||||
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
||||
@@ -514,7 +620,9 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "summarizetldr":
|
||||
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
||||
@@ -524,7 +632,9 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "jeopardy":
|
||||
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
||||
@@ -534,7 +644,9 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "oasst":
|
||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||
@@ -544,7 +656,9 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "gpteacher":
|
||||
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
||||
@@ -554,7 +668,9 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "reflection":
|
||||
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
||||
@@ -564,7 +680,9 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
else:
|
||||
suffix = ""
|
||||
|
||||
@@ -1,342 +0,0 @@
|
||||
# pylint: skip-file
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.dataloader")
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
# First-fit-decreasing bin packing
|
||||
# Check if a[] could fit in n bins with capacity c
|
||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||
|
||||
a = np.sort(a)[::-1]
|
||||
bins = np.full((n,), c, dtype=a.dtype)
|
||||
for size in a:
|
||||
not_found = True
|
||||
for idx in range(n):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
not_found = False
|
||||
break
|
||||
|
||||
if not_found:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||
# First-fit-decreasing bin packing (with result return)
|
||||
|
||||
indices = np.argsort(a)[::-1]
|
||||
a = a[indices]
|
||||
|
||||
bins: List[Any] = []
|
||||
bins_result: List[Any] = []
|
||||
for a_id, size in enumerate(a):
|
||||
add_new = True
|
||||
for idx in range(len(bins)):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
bins_result[idx].append(indices[a_id] + start_index)
|
||||
add_new = False
|
||||
break
|
||||
|
||||
if add_new:
|
||||
bins.append(c - size)
|
||||
bins_result.append([indices[a_id] + start_index])
|
||||
|
||||
return bins_result, len(a)
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate(
|
||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||
):
|
||||
"""
|
||||
:param lengths: array of lengths of each sample
|
||||
:param lengths_cumsum: cumulative sum of consecutive lengths
|
||||
:param rank: rank for this process
|
||||
:param c: length of tokens per batch
|
||||
:param n: number of ranks
|
||||
:return:
|
||||
"""
|
||||
# Dynamic batch allocator, similar to Multifit
|
||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||
|
||||
s = 0
|
||||
start_index = 0
|
||||
result = []
|
||||
result_totseqs = []
|
||||
|
||||
while True:
|
||||
# binary search [left, right)
|
||||
left = 1
|
||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||
|
||||
while right - left > 1:
|
||||
mid = (left + right) // 2
|
||||
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
||||
left = mid
|
||||
else:
|
||||
right = mid
|
||||
|
||||
# use length left
|
||||
batch, tot_seqs = ffd_with_result(
|
||||
lengths[start_index : start_index + left], c, start_index
|
||||
)
|
||||
if len(batch) < n:
|
||||
break
|
||||
|
||||
start_index += left
|
||||
s = lengths_cumsum[start_index - 1]
|
||||
|
||||
# add local rank
|
||||
result.append(batch[rank])
|
||||
# add total seqs for all ranks
|
||||
result_totseqs.append(tot_seqs)
|
||||
# yield batch[rank], tot_seqs, s, len(result) * c * n
|
||||
return result, result_totseqs, s, len(result) * c * n
|
||||
|
||||
|
||||
def chunk(iterable, n):
|
||||
"""
|
||||
Chunk data into tuples of length n
|
||||
"""
|
||||
# batched('ABCDEFG', 3) --> ABC DEF G
|
||||
if n < 1:
|
||||
raise ValueError("n must be at least one")
|
||||
it = iter(iterable)
|
||||
while batch := tuple(itertools.islice(it, n)):
|
||||
yield batch
|
||||
|
||||
|
||||
def hash_indices(lst: List[int]) -> str:
|
||||
# Convert the list of integers to a string representation
|
||||
concatenated = ",".join(map(str, lst))
|
||||
|
||||
# Generate the hash
|
||||
sha256 = hashlib.sha256()
|
||||
sha256.update(concatenated.encode())
|
||||
|
||||
return sha256.hexdigest()
|
||||
|
||||
|
||||
class MultipackDistributedDataloader:
|
||||
"""Unpadded data loading using Multipack.
|
||||
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
||||
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Any,
|
||||
collate_fn: Callable,
|
||||
seq_max_length: int = 2048,
|
||||
batch_size: int = 1,
|
||||
sampler: Union[Sampler, DistributedSampler] = None,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
sample_packing_seq_len_multiplier: int = 1,
|
||||
device_count: int = 1,
|
||||
prefetch_max: int = 1000,
|
||||
num_epochs: int = 1,
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
self.lengths = (
|
||||
dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
)
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
assert batch_size % sample_packing_seq_len_multiplier == 0
|
||||
assert batch_size >= sample_packing_seq_len_multiplier
|
||||
self.sampler = sampler
|
||||
self.batch_size = batch_size
|
||||
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
|
||||
self.seq_max_length = seq_max_length
|
||||
self.batch_max_length = batch_size * seq_max_length
|
||||
self.collate_fn = collate_fn
|
||||
self.num_epochs = num_epochs
|
||||
|
||||
self.num_replicas = 1
|
||||
self.rank = 0
|
||||
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.device_count = device_count
|
||||
|
||||
# maxsize is maximum number of samples in queue
|
||||
self.prefetch_max = prefetch_max
|
||||
self.queue: Queue = Queue(maxsize=prefetch_max)
|
||||
self.thread = None
|
||||
|
||||
def _worker(self):
|
||||
LOG.info(
|
||||
f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
|
||||
)
|
||||
for epoch in range(self.num_epochs):
|
||||
for sample in self._internal_batch_generator():
|
||||
while True:
|
||||
if self.queue.full():
|
||||
time.sleep(1)
|
||||
else:
|
||||
break
|
||||
self.queue.put(sample)
|
||||
|
||||
# stop the queue when epoch is done
|
||||
self.queue.put(None)
|
||||
|
||||
def __iter__(self):
|
||||
if hasattr(self.sampler, "set_epoch"):
|
||||
new_epoch = self.sampler.epoch + 1
|
||||
self.sampler.set_epoch(new_epoch)
|
||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||
|
||||
if self.thread is None:
|
||||
self.thread = Thread(target=self._worker, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
while True:
|
||||
item = self.queue.get()
|
||||
|
||||
if item is None:
|
||||
break
|
||||
yield item
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
LOG.info("generating packed batches")
|
||||
if self.sampler:
|
||||
indices = [idx for idx in self.sampler]
|
||||
else:
|
||||
indices = range(0, len(self.dataset))
|
||||
|
||||
LOG.info(hash_indices(indices))
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
batches, totseqs, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=self.rank,
|
||||
# c=self.batch_max_length,
|
||||
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
||||
n=self.num_replicas,
|
||||
)
|
||||
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used += total_used
|
||||
self.eff_total_slots += total_slots
|
||||
|
||||
return batches, totseqs
|
||||
|
||||
def _internal_batch_generator(self):
|
||||
all_batches, _ = self.generate_batches(set_stats=True)
|
||||
features = self.dataset.features.keys()
|
||||
len_remaining = self._len_est()
|
||||
for batches in chunk(
|
||||
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
|
||||
):
|
||||
chunked_data = []
|
||||
attn_mask_cum_idx = 0
|
||||
for batch in batches:
|
||||
concatenated = {}
|
||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||
for feature in features:
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||
for idx, item in enumerate(batched_data)
|
||||
if feature in item
|
||||
]
|
||||
attn_mask_cum_idx += len(batched_data)
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature])
|
||||
for item in batched_data
|
||||
if feature in item
|
||||
]
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
chunked_data.append(concatenated)
|
||||
yield self.collate_fn(chunked_data)
|
||||
len_remaining -= 1
|
||||
if not len_remaining:
|
||||
return
|
||||
# yield a no-op for cases where we don't have any data left to pack
|
||||
for i in range(0, len_remaining):
|
||||
yield self.collate_fn(
|
||||
[
|
||||
{
|
||||
"input_ids": [0],
|
||||
"labels": [-100],
|
||||
"attention_mask": [True],
|
||||
"position_ids": [0],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
def _len_est(self):
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // self.device_count
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
)
|
||||
|
||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||
return (
|
||||
math.floor(
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
// self.seq_max_length
|
||||
// self.batch_size
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
|
||||
# the same share of total tokens
|
||||
# if not self.eff_total_used:
|
||||
# batches, _ = self.generate_batches(set_stats=True)
|
||||
# LOG.info(
|
||||
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
# f"actual packing efficiency: {self.efficiency()}"
|
||||
# )
|
||||
return max(1, self._len_est())
|
||||
|
||||
def len_w_stats(self):
|
||||
if not self.eff_total_used:
|
||||
batches, _ = self.generate_batches(set_stats=True)
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"actual packing efficiency: {self.efficiency()}"
|
||||
)
|
||||
return max(1, self._len_est())
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
@@ -50,6 +50,17 @@ def get_world_size():
|
||||
return int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_only():
|
||||
"""
|
||||
Context manager that only runs the enclosed block on the main rank.
|
||||
"""
|
||||
if is_main_process():
|
||||
yield
|
||||
else:
|
||||
yield None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_first(is_main):
|
||||
"""
|
||||
|
||||
@@ -17,7 +17,6 @@ from transformers import ( # noqa: F401
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
GPTQConfig,
|
||||
LlamaConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
@@ -29,12 +28,40 @@ from axolotl.utils.dict import DictDefault
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
||||
quant_config_exists = hasattr(model_config, "quantization_config")
|
||||
quant_config_method_is_gptq = (
|
||||
quant_config_exists
|
||||
and "quant_method" in model_config.quantization_config
|
||||
and model_config.quantization_config["quant_method"] == "gptq"
|
||||
)
|
||||
|
||||
if cfg.gptq and not quant_config_method_is_gptq:
|
||||
raise ValueError(
|
||||
"model_config.quantization_config is not set or quant_method is not set to gptq. "
|
||||
"Please make sure to point to a GPTQ model."
|
||||
)
|
||||
|
||||
if not cfg.gptq and quant_config_exists:
|
||||
raise ValueError(
|
||||
"model_config.quantization_config is set but `gptq` flag is not. "
|
||||
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
||||
)
|
||||
|
||||
|
||||
def load_model_config(cfg):
|
||||
model_config_name = cfg.base_model_config or cfg.base_model
|
||||
trust_remote_code = cfg.trust_remote_code is True
|
||||
return AutoConfig.from_pretrained(
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_config_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
if cfg.model_config:
|
||||
for key, val in cfg.model_config.items():
|
||||
setattr(model_config, key, val)
|
||||
|
||||
check_model_config(cfg, model_config)
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
def load_tokenizer(cfg):
|
||||
@@ -51,7 +78,7 @@ def load_tokenizer(cfg):
|
||||
if cfg.tokenizer_type:
|
||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
@@ -72,11 +99,6 @@ def load_tokenizer(cfg):
|
||||
# set a pad_token, but use eos_token so we don't add a new token
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -85,6 +107,18 @@ def load_tokenizer(cfg):
|
||||
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# Qwen base only has single token, so we need to set the special tokens
|
||||
if cfg.is_qwen_derived_model:
|
||||
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
||||
for attr_name in token_ids:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, tokenizer.eod_id)
|
||||
|
||||
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
|
||||
for attr_name in token_names:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||
|
||||
if cfg.special_tokens:
|
||||
for k, val in cfg.special_tokens.items():
|
||||
tokenizer.add_special_tokens(
|
||||
@@ -98,6 +132,11 @@ def load_tokenizer(cfg):
|
||||
]
|
||||
)
|
||||
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@@ -110,7 +149,6 @@ def load_model(
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
base_model = cfg.base_model
|
||||
base_model_config = cfg.base_model_config
|
||||
model_type = cfg.model_type
|
||||
model_config = load_model_config(cfg)
|
||||
|
||||
@@ -201,6 +239,7 @@ def load_model(
|
||||
model_kwargs = {}
|
||||
|
||||
model_kwargs["device_map"] = cfg.device_map
|
||||
model_kwargs["max_memory"] = cfg.max_memory
|
||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||
|
||||
if cfg.model_revision:
|
||||
@@ -238,16 +277,9 @@ def load_model(
|
||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
config_kwargs = {}
|
||||
if cfg.rope_scaling:
|
||||
config_kwargs["rope_scaling"] = cfg.rope_scaling
|
||||
config = LlamaConfig.from_pretrained(
|
||||
base_model_config,
|
||||
**config_kwargs,
|
||||
)
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
**model_kwargs,
|
||||
@@ -292,10 +324,10 @@ def load_model(
|
||||
# device=cfg.device,
|
||||
# )
|
||||
# model.train() # sets to train instead of eval mode
|
||||
elif model_type == "MixFormerSequentialForCausalLM":
|
||||
from axolotl.models.phi import MixFormerSequentialForCausalLM
|
||||
elif model_type == "PhiForCausalLM":
|
||||
from axolotl.models.phi import PhiForCausalLM
|
||||
|
||||
model = MixFormerSequentialForCausalLM.from_pretrained(
|
||||
model = PhiForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
@@ -305,66 +337,55 @@ def load_model(
|
||||
if cfg.gptq:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = getattr(transformers, model_type).from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(
|
||||
base_model,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
)
|
||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||
# when training starts
|
||||
if (
|
||||
hasattr(config, "max_seq_len")
|
||||
and config.max_seq_len
|
||||
and cfg.sequence_len > config.max_seq_len
|
||||
hasattr(model_config, "max_seq_len")
|
||||
and model_config.max_seq_len
|
||||
and cfg.sequence_len > model_config.max_seq_len
|
||||
):
|
||||
config.max_seq_len = cfg.sequence_len
|
||||
model_config.max_seq_len = cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
elif (
|
||||
hasattr(config, "max_sequence_length")
|
||||
and config.max_sequence_length
|
||||
and cfg.sequence_len > config.max_sequence_length
|
||||
hasattr(model_config, "max_sequence_length")
|
||||
and model_config.max_sequence_length
|
||||
and cfg.sequence_len > model_config.max_sequence_length
|
||||
):
|
||||
config.max_sequence_length = cfg.sequence_len
|
||||
model_config.max_sequence_length = cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
if cfg.gptq:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
config=model_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
except Exception as err: # pylint: disable=broad-exception-caught
|
||||
LOG.error(
|
||||
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
||||
)
|
||||
LOG.exception(err)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
raise err
|
||||
|
||||
embeddings_len = (
|
||||
math.ceil(len(tokenizer) / 32) * 32
|
||||
@@ -386,6 +407,20 @@ def load_model(
|
||||
)
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if (
|
||||
hasattr(model.config, "bos_token_id")
|
||||
and model.config.bos_token_id
|
||||
and model.config.bos_token_id != tokenizer.bos_token_id
|
||||
):
|
||||
model.config.bos_token_id = tokenizer.bos_token_id
|
||||
|
||||
if (
|
||||
hasattr(model.config, "eos_token_id")
|
||||
and model.config.eos_token_id
|
||||
and model.config.eos_token_id != tokenizer.eos_token_id
|
||||
):
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
@@ -401,15 +436,22 @@ def load_model(
|
||||
module.to(torch.float32)
|
||||
|
||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||
skip_prepare_model_for_kbit_training = False
|
||||
|
||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if (cfg.adapter == "lora" and load_in_8bit) or (
|
||||
cfg.adapter == "qlora" and cfg.load_in_4bit
|
||||
):
|
||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||
if cfg.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
if not skip_prepare_model_for_kbit_training:
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
needs_fa2_dtype = True
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
@@ -428,14 +470,7 @@ def load_model(
|
||||
if cfg.ddp and not load_in_8bit:
|
||||
model.to(f"cuda:{cfg.local_rank}")
|
||||
|
||||
if (
|
||||
torch.cuda.device_count() > 1
|
||||
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
||||
and (cfg.load_in_4bit)
|
||||
):
|
||||
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
||||
# so let's only set it for the 4bit, see
|
||||
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
||||
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
setattr(model, "is_parallelizable", True)
|
||||
setattr(model, "model_parallel", True)
|
||||
|
||||
|
||||
4
src/axolotl/utils/samplers/__init__.py
Normal file
4
src/axolotl/utils/samplers/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
axolotl samplers module
|
||||
"""
|
||||
from .multipack import MultipackBatchSampler # noqa: F401
|
||||
196
src/axolotl/utils/samplers/multipack.py
Normal file
196
src/axolotl/utils/samplers/multipack.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# pylint: skip-file
|
||||
"""
|
||||
Multipack Batch Sampler
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Iterable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler, Sampler
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
# First-fit-decreasing bin packing
|
||||
# Check if a[] could fit in n bins with capacity c
|
||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||
|
||||
a = np.sort(a)[::-1]
|
||||
bins = np.full((n,), c, dtype=a.dtype)
|
||||
for size in a:
|
||||
not_found = True
|
||||
for idx in range(n):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
not_found = False
|
||||
break
|
||||
|
||||
if not_found:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||
# First-fit-decreasing bin packing (with result return)
|
||||
|
||||
indices = np.argsort(a)[::-1]
|
||||
a = a[indices]
|
||||
|
||||
bins: List[Any] = []
|
||||
bins_result: List[Any] = []
|
||||
for a_id, size in enumerate(a):
|
||||
add_new = True
|
||||
for idx in range(len(bins)):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
bins_result[idx].append(indices[a_id] + start_index)
|
||||
add_new = False
|
||||
break
|
||||
|
||||
if add_new:
|
||||
bins.append(c - size)
|
||||
bins_result.append([indices[a_id] + start_index])
|
||||
|
||||
return bins_result
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate(
|
||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||
):
|
||||
# Dynamic batch allocator, similar to Multifit
|
||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||
|
||||
s = 0
|
||||
start_index = 0
|
||||
result = []
|
||||
|
||||
while True:
|
||||
# binary search [l, r)
|
||||
left = 1
|
||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||
|
||||
while right - left > 1:
|
||||
mid = (left + right) // 2
|
||||
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
||||
left = mid
|
||||
else:
|
||||
right = mid
|
||||
|
||||
# use length l
|
||||
batch = ffd_with_result(
|
||||
lengths[start_index : start_index + left], c, start_index
|
||||
)
|
||||
assert len(batch) <= n
|
||||
if len(batch) < n:
|
||||
break
|
||||
|
||||
start_index += left
|
||||
s = lengths_cumsum[start_index - 1]
|
||||
|
||||
# add local rank
|
||||
result.append(batch[rank])
|
||||
|
||||
return result, s, len(result) * c * n
|
||||
|
||||
|
||||
class MultipackBatchSampler(BatchSampler):
|
||||
"""
|
||||
Batch Sampler class for multipack
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampler: Union[Sampler[int], Iterable[int]],
|
||||
batch_size: int,
|
||||
drop_last: bool,
|
||||
batch_max_len: int,
|
||||
lengths: np.ndarray,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.batch_size = None
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
self.epoch = 0
|
||||
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
self.epoch = epoch
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
indices = [idx for idx in self.sampler]
|
||||
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used += total_used
|
||||
self.eff_total_slots += total_slots
|
||||
|
||||
return batches
|
||||
|
||||
def __iter__(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
return iter(batches)
|
||||
|
||||
def num_batches(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
return len(batches)
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
|
||||
def __len__(self):
|
||||
self.num_batches()
|
||||
return self._len_est()
|
||||
|
||||
def _len_est(self):
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // world_size
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
)
|
||||
|
||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||
return max(
|
||||
0,
|
||||
(
|
||||
world_size
|
||||
* math.floor(
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
// self.batch_max_len
|
||||
)
|
||||
- 1
|
||||
),
|
||||
)
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
@@ -9,21 +8,15 @@ from typing import List
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import set_caching_enabled
|
||||
from torch.utils.data import DistributedSampler, RandomSampler
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder
|
||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||
from axolotl.utils.distributed import (
|
||||
is_distributed,
|
||||
is_main_process,
|
||||
reduce_and_broadcast,
|
||||
zero_first,
|
||||
)
|
||||
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
||||
from axolotl.utils.samplers import MultipackBatchSampler
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
LOG = get_logger("axolotl")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -148,30 +141,35 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
if not cfg.total_num_tokens:
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.data.column("input_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||
.values
|
||||
)
|
||||
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
|
||||
if update:
|
||||
cfg.total_num_tokens = total_num_tokens
|
||||
|
||||
if not cfg.total_supervised_tokens:
|
||||
total_supervised_tokens = (
|
||||
train_dataset.data.column("labels")
|
||||
.to_pandas()
|
||||
.apply(lambda x: np.sum(np.array(x) != -100))
|
||||
.sum()
|
||||
)
|
||||
LOG.debug(
|
||||
f"`total_supervised_tokens: {total_supervised_tokens}`",
|
||||
main_process_only=True,
|
||||
)
|
||||
if update:
|
||||
cfg.total_supervised_tokens = total_supervised_tokens
|
||||
|
||||
if cfg.sample_packing:
|
||||
# we have to drop anything longer then sequence len otherwise
|
||||
# flash attention with position ids fails
|
||||
if not cfg.total_num_tokens:
|
||||
LOG.info("calculating total_num_tokens")
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.data.column("input_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||
.values
|
||||
)
|
||||
LOG.info(f"total_num_tokens: {total_num_tokens}")
|
||||
cfg.total_num_tokens = total_num_tokens
|
||||
|
||||
if not cfg.total_supervised_tokens:
|
||||
total_supervised_tokens = (
|
||||
train_dataset.data.column("labels")
|
||||
.to_pandas()
|
||||
.apply(lambda x: np.sum(np.array(x) != -100))
|
||||
.sum()
|
||||
)
|
||||
LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
|
||||
cfg.total_supervised_tokens = total_supervised_tokens
|
||||
|
||||
if cfg.sample_packing_eff_est:
|
||||
total_num_steps = (
|
||||
@@ -189,41 +187,41 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
)
|
||||
* cfg.num_epochs
|
||||
)
|
||||
LOG.info(
|
||||
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
|
||||
LOG.debug(
|
||||
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
|
||||
main_process_only=True,
|
||||
)
|
||||
else:
|
||||
if cfg.world_size > 1 and is_distributed():
|
||||
sampler = DistributedSampler(
|
||||
train_dataset,
|
||||
num_replicas=cfg.world_size,
|
||||
rank=dist.get_rank(),
|
||||
seed=cfg.seed or 42,
|
||||
)
|
||||
else:
|
||||
sampler = RandomSampler(train_dataset)
|
||||
|
||||
data_loader = MultipackDistributedDataloader(
|
||||
train_dataset,
|
||||
sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
||||
collate_fn=DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
drop_last=True,
|
||||
batch_max_len=cfg.micro_batch_size
|
||||
* (cfg.max_packed_sequence_len or cfg.sequence_len),
|
||||
lengths=(
|
||||
train_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
sampler=sampler,
|
||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
num_epochs=cfg.num_epochs,
|
||||
)
|
||||
data_loader_len = data_loader.len_w_stats()
|
||||
actual_eff = data_loader.efficiency()
|
||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
|
||||
data_loader = DataLoader(
|
||||
train_dataset.remove_columns(["length"]),
|
||||
batch_sampler=sampler,
|
||||
)
|
||||
data_loader_len = len(data_loader)
|
||||
actual_eff = sampler.efficiency()
|
||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
# on the agreed on value for sample_packing_eff_est
|
||||
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len
|
||||
* cfg.num_epochs
|
||||
/ int(os.environ.get("WORLD_SIZE", 1))
|
||||
)
|
||||
)
|
||||
|
||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||
@@ -236,13 +234,22 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
sample_packing_eff_est = (
|
||||
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
||||
)
|
||||
cfg.sample_packing_eff_est = sample_packing_eff_est
|
||||
LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
|
||||
if update:
|
||||
cfg.sample_packing_eff_est = sample_packing_eff_est
|
||||
LOG.debug(
|
||||
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
|
||||
main_process_only=True,
|
||||
)
|
||||
else:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
/ int(os.environ.get("WORLD_SIZE", 1))
|
||||
/ cfg.batch_size
|
||||
)
|
||||
)
|
||||
LOG.info(f"total_num_steps: {total_num_steps}")
|
||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
||||
return total_num_steps
|
||||
|
||||
|
||||
@@ -260,12 +267,14 @@ def setup_fsdp_envs(cfg):
|
||||
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
def prepare_optim_env(cfg):
|
||||
if cfg.fsdp:
|
||||
setup_fsdp_envs(cfg)
|
||||
elif cfg.deepspeed:
|
||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
||||
trainer_builder.train_dataset = train_dataset
|
||||
trainer_builder.eval_dataset = eval_dataset
|
||||
|
||||
@@ -2,20 +2,20 @@
|
||||
|
||||
import os
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
def setup_wandb_env_vars(cfg):
|
||||
if cfg.wandb_mode and cfg.wandb_mode == "offline":
|
||||
os.environ["WANDB_MODE"] = cfg.wandb_mode
|
||||
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
||||
|
||||
def setup_wandb_env_vars(cfg: DictDefault):
|
||||
for key in cfg.keys():
|
||||
if key.startswith("wandb_"):
|
||||
value = cfg.get(key, "")
|
||||
|
||||
if value and isinstance(value, str) and len(value) > 0:
|
||||
os.environ[key.upper()] = value
|
||||
|
||||
# Enable wandb if project name is present
|
||||
if cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||
cfg.use_wandb = True
|
||||
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
|
||||
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
|
||||
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
||||
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
||||
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
||||
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
||||
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
||||
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
||||
os.environ.pop("WANDB_DISABLED", None) # Remove if present
|
||||
else:
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
0
tests/e2e/__init__.py
Normal file
0
tests/e2e/__init__.py
Normal file
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -16,6 +15,8 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -25,9 +26,9 @@ class TestFusedLlama(unittest.TestCase):
|
||||
Test case for Llama models using Fused layers
|
||||
"""
|
||||
|
||||
def test_fft_packing(self):
|
||||
@with_temp_dir
|
||||
def test_fft_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
@@ -51,7 +52,7 @@ class TestFusedLlama(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -69,4 +70,4 @@ class TestFusedLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -14,6 +13,8 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -23,9 +24,9 @@ class TestLoraLlama(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
def test_lora(self):
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
@@ -52,7 +53,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -63,11 +64,11 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_lora_packing(self):
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
@@ -96,10 +97,11 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
@@ -107,11 +109,11 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_lora_gptq(self):
|
||||
@with_temp_dir
|
||||
def test_lora_gptq(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
||||
@@ -144,7 +146,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"save_steps": 0.5,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -155,4 +157,4 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -16,6 +15,8 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -25,9 +26,9 @@ class TestMistral(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
def test_lora(self):
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
@@ -54,7 +55,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -68,11 +69,11 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_ft(self):
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
@@ -93,7 +94,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -111,4 +112,4 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -16,6 +15,8 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -25,9 +26,9 @@ class TestMistral(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
def test_lora_packing(self):
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
@@ -55,7 +56,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -69,11 +70,11 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_ft_packing(self):
|
||||
@with_temp_dir
|
||||
def test_ft_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
@@ -95,7 +96,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -113,4 +114,4 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@@ -4,8 +4,8 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,6 +13,8 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -22,13 +24,14 @@ class TestPhi(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
def test_ft(self):
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
"trust_remote_code": True,
|
||||
"model_type": "MixFormerSequentialForCausalLM",
|
||||
"model_type": "PhiForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 512,
|
||||
"sample_packing": False,
|
||||
@@ -52,7 +55,7 @@ class TestPhi(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": tempfile.mkdtemp(),
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -64,14 +67,16 @@ class TestPhi(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
def test_ft_packed(self):
|
||||
@with_temp_dir
|
||||
def test_ft_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
"trust_remote_code": True,
|
||||
"model_type": "MixFormerSequentialForCausalLM",
|
||||
"model_type": "PhiForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 512,
|
||||
"sample_packing": True,
|
||||
@@ -95,7 +100,7 @@ class TestPhi(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": tempfile.mkdtemp(),
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -107,3 +112,4 @@ class TestPhi(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
95
tests/e2e/test_resume.py
Normal file
95
tests/e2e/test_resume.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
E2E tests for resuming training
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import most_recent_subdir, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestResumeLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for resuming training of llama models
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_resume_qlora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"flash_attention": True,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "vicgalle/alpaca-gpt4",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_steps": 10,
|
||||
"save_total_limit": 5,
|
||||
"max_steps": 40,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
resume_cfg = cfg | DictDefault(
|
||||
{
|
||||
"resume_from_checkpoint": f"{temp_dir}/checkpoint-30/",
|
||||
}
|
||||
)
|
||||
normalize_config(resume_cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
|
||||
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
|
||||
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"
|
||||
res = subprocess.run(
|
||||
cmd, shell=True, text=True, capture_output=True, check=True
|
||||
)
|
||||
pattern = r"first_step\s+(\d+)"
|
||||
first_steps = int(re.findall(pattern, res.stdout)[0])
|
||||
assert first_steps == 31
|
||||
33
tests/e2e/utils.py
Normal file
33
tests/e2e/utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
helper utils for tests
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def with_temp_dir(test_func):
|
||||
@wraps(test_func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Create a temporary directory
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
# Pass the temporary directory to the test function
|
||||
test_func(*args, temp_dir=temp_dir, **kwargs)
|
||||
finally:
|
||||
# Clean up the directory after the test
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def most_recent_subdir(path):
|
||||
base_path = Path(path)
|
||||
subdirectories = [d for d in base_path.iterdir() if d.is_dir()]
|
||||
if not subdirectories:
|
||||
return None
|
||||
subdir = max(subdirectories, key=os.path.getctime)
|
||||
|
||||
return subdir
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Module for testing the validation module"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
@@ -8,6 +9,7 @@ import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
|
||||
class ValidationTest(unittest.TestCase):
|
||||
@@ -649,3 +651,113 @@ class ValidationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_warmup_step_no_conflict(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"warmup_steps": 10,
|
||||
"warmup_ratio": 0.1,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"warmup_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"warmup_ratio": 0.1,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
|
||||
class ValidationWandbTest(ValidationTest):
|
||||
"""
|
||||
Validation test for wandb
|
||||
"""
|
||||
|
||||
def test_wandb_set_run_id_to_name(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_run_id": "foo",
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_name": "foo",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
|
||||
|
||||
def test_wandb_sets_env(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_project": "foo",
|
||||
"wandb_name": "bar",
|
||||
"wandb_run_id": "bat",
|
||||
"wandb_entity": "baz",
|
||||
"wandb_mode": "online",
|
||||
"wandb_watch": "false",
|
||||
"wandb_log_model": "checkpoint",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
assert os.environ.get("WANDB_PROJECT", "") == "foo"
|
||||
assert os.environ.get("WANDB_NAME", "") == "bar"
|
||||
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
|
||||
assert os.environ.get("WANDB_ENTITY", "") == "baz"
|
||||
assert os.environ.get("WANDB_MODE", "") == "online"
|
||||
assert os.environ.get("WANDB_WATCH", "") == "false"
|
||||
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
|
||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
||||
|
||||
def test_wandb_set_disabled(self):
|
||||
cfg = DictDefault({})
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
assert os.environ.get("WANDB_DISABLED", "") == "true"
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_project": "foo",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
||||
|
||||
Reference in New Issue
Block a user