Compare commits
28 Commits
llava-trai
...
fp8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8836986a92 | ||
|
|
105d0b350b | ||
|
|
f544ab2bed | ||
|
|
641e6f7e51 | ||
|
|
6dc68a653f | ||
|
|
7de6a5639c | ||
|
|
c74f045ba7 | ||
|
|
0402d19759 | ||
|
|
b2430ce670 | ||
|
|
4c834bf25d | ||
|
|
8056ecd30e | ||
|
|
738a057674 | ||
|
|
cdc71f73c8 | ||
|
|
6459ac7357 | ||
|
|
964d858da0 | ||
|
|
10388a8daf | ||
|
|
9f7e8a971d | ||
|
|
637ed095a0 | ||
|
|
827ec3d274 | ||
|
|
8b79ff0e94 | ||
|
|
0800885e2f | ||
|
|
d3193beac3 | ||
|
|
2e71ff03a6 | ||
|
|
facc49f32b | ||
|
|
e50ab072e2 | ||
|
|
05bd6f1122 | ||
|
|
20aa4b57d2 | ||
|
|
11d1d607db |
99
README.md
99
README.md
@@ -32,7 +32,6 @@ Features:
|
||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||
- [Config](#config)
|
||||
- [Train](#train)
|
||||
- [Training w/ Deepspeed](#training-with-deepspeed)
|
||||
- [Inference](#inference)
|
||||
- [Merge LORA to Base](#merge-lora-to-base)
|
||||
- [Common Errors](#common-errors-)
|
||||
@@ -75,6 +74,7 @@ Features:
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||
|
||||
|
||||
## Quickstart ⚡
|
||||
@@ -97,6 +97,10 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||
# inference
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out"
|
||||
|
||||
# gradio
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out" --gradio
|
||||
```
|
||||
|
||||
## Installation
|
||||
@@ -115,6 +119,25 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Docker advanced</summary>
|
||||
|
||||
A more powerful Docker command to run would be this:
|
||||
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
|
||||
It additionally:
|
||||
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
|
||||
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
|
||||
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
|
||||
|
||||
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
|
||||
|
||||
</details>
|
||||
|
||||
#### Conda/Pip venv
|
||||
1. Install python >=**3.9**
|
||||
|
||||
@@ -356,6 +379,13 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- typescript
|
||||
type: ... # unimplemented custom format
|
||||
|
||||
# fastchat conversation
|
||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
datasets:
|
||||
- path: ...
|
||||
type: sharegpt
|
||||
conversation: chatml
|
||||
|
||||
# local
|
||||
datasets:
|
||||
- path: data.jsonl # or json
|
||||
@@ -394,7 +424,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
|
||||
<details>
|
||||
|
||||
<summary>All yaml options</summary>
|
||||
<summary>All yaml options (click me)</summary>
|
||||
|
||||
```yaml
|
||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
@@ -461,7 +491,9 @@ datasets:
|
||||
data_files: # Optional[str] path to source data files
|
||||
shards: # Optional[int] number of shards to split data into
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
|
||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
|
||||
# Custom user prompt
|
||||
- path: repo
|
||||
@@ -591,14 +623,14 @@ gradient_accumulation_steps: 1
|
||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
micro_batch_size: 2
|
||||
eval_batch_size:
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
save_strategy: # Set to `no` to skip checkpoint saves
|
||||
save_steps: # Leave empty to save at each epoch
|
||||
eval_steps: # Leave empty to eval at each epoch
|
||||
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
@@ -815,14 +847,41 @@ Run
|
||||
accelerate launch -m axolotl.cli.train your_config.yml
|
||||
```
|
||||
|
||||
#### Multi-GPU
|
||||
#### Preprocess dataset
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning.
|
||||
This is recommended for large datasets.
|
||||
|
||||
- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
|
||||
- Use `--debug` to see preprocessed examples.
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
|
||||
python -m axolotl.cli.preprocess your_config.yml
|
||||
```
|
||||
|
||||
##### Config
|
||||
#### Multi-GPU
|
||||
|
||||
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
|
||||
is the recommended multi-GPU option currently because FSDP may experience
|
||||
[loss instability](https://github.com/huggingface/transformers/issues/26498).
|
||||
|
||||
##### DeepSpeed
|
||||
|
||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||
|
||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero1.json
|
||||
```
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||
```
|
||||
|
||||
##### FSDP
|
||||
|
||||
- llama FSDP
|
||||
```yaml
|
||||
@@ -847,24 +906,6 @@ wandb_run_id:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
### Training with Deepspeed
|
||||
|
||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||
|
||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero1.json
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Pass the appropriate flag to the train command:
|
||||
@@ -882,6 +923,10 @@ Pass the appropriate flag to the train command:
|
||||
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
|
||||
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
||||
```
|
||||
-- With gradio hosting
|
||||
```bash
|
||||
python -m axolotl.cli.inference examples/your_config.yml --gradio
|
||||
```
|
||||
|
||||
Please use `--sample_packing False` if you have it on and receive the error similar to below:
|
||||
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 0,
|
||||
@@ -41,12 +33,13 @@
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear"
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
|
||||
@@ -21,9 +21,9 @@ WORKDIR /workspace/axolotl
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
else \
|
||||
pip install -e .[flash-attn]; \
|
||||
pip install -e .[deepspeed,flash-attn]; \
|
||||
fi
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
|
||||
@@ -10,8 +10,10 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ARG PYTHON_VERSION="3.9"
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
ARG CUDA="118"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
||||
@@ -27,47 +29,9 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
|
||||
FROM base-builder AS deepspeed-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
cd DeepSpeed && \
|
||||
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
|
||||
|
||||
FROM base-builder AS bnb-builder
|
||||
|
||||
WORKDIR /workspace
|
||||
ARG CUDA="118"
|
||||
ENV CUDA=$CUDA
|
||||
ARG MAX_JOBS="-1"
|
||||
ENV MAX_JOBS=$MAX_JOBS
|
||||
|
||||
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
|
||||
cd bitsandbytes && \
|
||||
CUDA_VERSION=$CUDA make cuda11x && \
|
||||
python setup.py bdist_wheel
|
||||
|
||||
FROM base-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN mkdir -p /workspace/builds
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
||||
|
||||
RUN mkdir -p /workspace/wheels/bitsandbytes
|
||||
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
|
||||
|
||||
RUN pip3 install wheels/deepspeed-*.whl
|
||||
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
|
||||
RUN git lfs install --skip-repo
|
||||
RUN pip3 install awscli && \
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
@@ -12,3 +12,7 @@ This usually happens when you run out of system RAM.
|
||||
> Exitcode -7 while using deepspeed
|
||||
|
||||
Try upgrading deepspeed w: `pip install -U deepspeed`
|
||||
|
||||
> AttributeError: 'DummyOptim' object has no attribute 'step'
|
||||
|
||||
You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
||||
|
||||
@@ -14,7 +14,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_prepared_run
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
|
||||
@@ -7,7 +7,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -49,7 +49,7 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -34,7 +34,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +54,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -36,7 +36,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,7 +56,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -34,7 +34,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +54,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -36,7 +36,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,7 +56,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -34,7 +34,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +54,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -36,7 +36,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,7 +56,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -12,7 +12,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
|
||||
@@ -18,7 +18,7 @@ datasets:
|
||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||
type: "alpaca:chat"
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -53,7 +53,7 @@ output_dir: ./qlora-out
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
|
||||
@@ -12,7 +12,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
|
||||
@@ -7,7 +7,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -46,7 +46,7 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -24,7 +24,7 @@ wandb_log_model:
|
||||
output_dir: ./jeopardy-bot-7b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 4096
|
||||
|
||||
@@ -15,7 +15,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 4096
|
||||
@@ -37,7 +37,7 @@ wandb_log_model:
|
||||
output_dir: ./model-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_eps: 0.00001
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -34,7 +34,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +54,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -36,7 +36,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,7 +56,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
debug:
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./relora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -40,7 +40,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -60,7 +60,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps: 50
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -12,7 +12,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -34,7 +34,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +54,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
debug:
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 8192
|
||||
@@ -26,7 +26,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000005
|
||||
@@ -46,7 +46,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -63,7 +63,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
|
||||
@@ -26,7 +26,7 @@ wandb_log_model:
|
||||
output_dir: ./mpt-alpaca-7b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
|
||||
@@ -9,7 +9,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
|
||||
@@ -23,7 +23,7 @@ wandb_log_model:
|
||||
output_dir: ./lora-alpaca-pythia
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
@@ -33,5 +33,5 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
weight_decay: 0.1
|
||||
eval_steps: 20
|
||||
eval_steps: 0.05
|
||||
logging_steps: 1
|
||||
|
||||
@@ -27,7 +27,7 @@ wandb_log_model:
|
||||
output_dir: ./redpajama-alpaca-3b
|
||||
batch_size: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
|
||||
@@ -26,7 +26,7 @@ wandb_log_model:
|
||||
output_dir: ./lora-replit
|
||||
batch_size: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer:
|
||||
torchdistx_path:
|
||||
lr_scheduler:
|
||||
|
||||
@@ -16,7 +16,7 @@ datasets:
|
||||
- openassistant_best_replies_train.jsonl
|
||||
type: "completion"
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -51,7 +51,7 @@ output_dir: ./qlora-out
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
torch==2.0.1
|
||||
auto-gptq
|
||||
auto-gptq==0.4.2
|
||||
packaging
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
peft==0.6.0
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
|
||||
@@ -17,7 +17,7 @@ sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers>=0.0.22
|
||||
optimum
|
||||
optimum==1.13.2
|
||||
hf_transfer
|
||||
colorama
|
||||
numba
|
||||
@@ -31,3 +31,4 @@ scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat==0.2.29
|
||||
gradio
|
||||
|
||||
@@ -45,8 +45,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
return
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
|
||||
@@ -6,8 +6,10 @@ import os
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
@@ -16,7 +18,7 @@ from accelerate.commands.config import config_args
|
||||
from art import text2art
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.logging_config import configure_logging
|
||||
@@ -153,6 +155,91 @@ def do_inference(
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
def do_inference_gradio(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
model = model.to(cfg.device)
|
||||
|
||||
def generate(instruction):
|
||||
if not instruction:
|
||||
return
|
||||
if prompter_module:
|
||||
# pylint: disable=stop-iteration-return
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer)
|
||||
generation_kwargs = {
|
||||
"inputs": batch["input_ids"].to(cfg.device),
|
||||
"generation_config": generation_config,
|
||||
"streamer": streamer,
|
||||
}
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
all_text = ""
|
||||
|
||||
for new_text in streamer:
|
||||
all_text += new_text
|
||||
yield all_text
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs="textbox",
|
||||
outputs="text",
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
demo.queue().launch(show_api=False, share=True)
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
yaml_files = list(path.glob("*.yml"))
|
||||
|
||||
@@ -222,7 +309,9 @@ def load_datasets(
|
||||
) -> TrainDatasetMeta:
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||
cfg, tokenizer
|
||||
)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
@@ -238,6 +327,10 @@ def load_datasets(
|
||||
text_only=cli_args.debug_text_only,
|
||||
)
|
||||
|
||||
LOG.info("printing prompters...")
|
||||
for prompter in prompters:
|
||||
LOG.info(prompter)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
|
||||
@@ -6,11 +6,16 @@ from pathlib import Path
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
|
||||
from axolotl.cli import (
|
||||
do_inference,
|
||||
do_inference_gradio,
|
||||
load_cfg,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
)
|
||||
parsed_cli_args.inference = True
|
||||
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if gradio:
|
||||
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
53
src/axolotl/cli/preprocess.py
Normal file
53
src/axolotl/cli/preprocess.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from colorama import Fore
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import PreprocessCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((PreprocessCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
if not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
+ "preprocess CLI called without dataset_prepared_path set, "
|
||||
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
|
||||
+ Fore.RESET
|
||||
)
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
LOG.info(
|
||||
Fore.GREEN
|
||||
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
||||
+ Fore.RESET
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from colorama import Fore
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
@@ -16,7 +15,6 @@ from axolotl.cli import (
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.train import train
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
@@ -32,18 +30,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
+ "--prepare_ds_only called without dataset_prepared_path set."
|
||||
+ Fore.RESET
|
||||
)
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
return
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
|
||||
@@ -25,11 +25,22 @@ class TrainerCliArgs:
|
||||
debug_num_examples: int = field(default=5)
|
||||
inference: bool = field(default=False)
|
||||
merge_lora: bool = field(default=False)
|
||||
prepare_ds_only: bool = field(default=False)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessCliArgs:
|
||||
"""
|
||||
dataclass representing arguments for preprocessing only
|
||||
"""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=1)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
|
||||
@@ -6,7 +6,6 @@ import abc
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
@@ -18,9 +17,9 @@ import torch
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import SequentialDistributedSampler
|
||||
from transformers.trainer_utils import seed_worker
|
||||
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
@@ -31,8 +30,9 @@ from axolotl.utils.callbacks import (
|
||||
bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||
from axolotl.utils.samplers import MultipackBatchSampler
|
||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||
|
||||
try:
|
||||
@@ -102,6 +102,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -111,7 +115,8 @@ class AxolotlTrainer(Trainer):
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
|
||||
def __init__(self, *args, bench_data_collator=None, **kwargs):
|
||||
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
||||
self.num_epochs = num_epochs
|
||||
self.bench_data_collator = bench_data_collator
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -144,45 +149,69 @@ class AxolotlTrainer(Trainer):
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size > 1 and self.args.sample_packing:
|
||||
return DistributedSampler(
|
||||
self.train_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
seed=self.args.seed,
|
||||
if self.args.sample_packing:
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
self.args.train_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
||||
lengths=(
|
||||
self.train_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if (
|
||||
self.args.world_size > 1
|
||||
and self.args.sample_packing
|
||||
and self.args.eval_sample_packing is not False
|
||||
):
|
||||
return SequentialDistributedSampler(
|
||||
eval_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
batch_size=self.args.per_device_eval_batch_size,
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
self.args.per_device_eval_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
||||
lengths=(
|
||||
eval_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.args.sample_packing:
|
||||
train_sampler = self._get_train_sampler()
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
self.train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=train_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
train_dataset = self.train_dataset
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
sampler = self._get_train_sampler()
|
||||
if isinstance(sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
@@ -195,17 +224,29 @@ class AxolotlTrainer(Trainer):
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=eval_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if isinstance(eval_sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = eval_sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = eval_sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(eval_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
@@ -226,6 +267,8 @@ class AxolotlTrainer(Trainer):
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
@@ -440,6 +483,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["fp16"] = (
|
||||
self.cfg.fp16 and not self.cfg.bf16
|
||||
) or False
|
||||
if self.cfg.fp8:
|
||||
training_arguments_kwargs["fp16"] = False
|
||||
training_arguments_kwargs["bf16"] = False
|
||||
|
||||
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
@@ -490,6 +537,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"sample_packing_efficiency"
|
||||
] = self.cfg.sample_packing_eff_est
|
||||
|
||||
if self.cfg.dataloader_pin_memory is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_pin_memory"
|
||||
] = self.cfg.dataloader_pin_memory
|
||||
if self.cfg.dataloader_num_workers is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_num_workers"
|
||||
] = self.cfg.dataloader_num_workers
|
||||
if self.cfg.dataloader_prefetch_factor is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_prefetch_factor"
|
||||
] = self.cfg.dataloader_prefetch_factor
|
||||
|
||||
if self.cfg.eval_steps:
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
@@ -669,7 +729,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
data_collator=BatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
@@ -680,10 +740,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
callbacks=self.get_callbacks(),
|
||||
num_epochs=self.cfg.num_epochs,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
if self.cfg.deepspeed and self.cfg.sample_packing:
|
||||
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
|
||||
"train_micro_batch_size_per_gpu"
|
||||
] = self.cfg.micro_batch_size
|
||||
|
||||
return trainer
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: IterableDataset,
|
||||
process_count: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
self.process_count = process_count
|
||||
super().__init__(self.process(dataset).data, **kwargs)
|
||||
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, os.cpu_count())
|
||||
num_proc = (
|
||||
min(64, self.process_count)
|
||||
if self.process_count
|
||||
else min(64, os.cpu_count())
|
||||
)
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
"""
|
||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
|
||||
import torch
|
||||
import transformers.models.llama.modeling_llama
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||
# pylint: disable=duplicate-code
|
||||
def noised_embed(orig_embed, noise_alpha, model):
|
||||
def new_func(input_ids):
|
||||
# during training, we add noise to the embedding
|
||||
# during generation, we don't add noise to the embedding
|
||||
if model.training:
|
||||
embed_init = orig_embed(input_ids)
|
||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
return orig_embed(input_ids)
|
||||
|
||||
return new_func
|
||||
|
||||
def post_init(orig_post_init):
|
||||
def new_func(self):
|
||||
orig_post_init(self)
|
||||
self.embed_tokens.forward = noised_embed(
|
||||
self.embed_tokens.forward, noise_alpha, self
|
||||
)
|
||||
|
||||
return new_func
|
||||
|
||||
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
|
||||
transformers.models.llama.modeling_llama.LlamaModel.post_init
|
||||
)
|
||||
@@ -1,40 +0,0 @@
|
||||
"""
|
||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
|
||||
import torch
|
||||
import transformers.models.mistral.modeling_mistral
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||
# pylint: disable=duplicate-code
|
||||
def noised_embed(orig_embed, noise_alpha, model):
|
||||
def new_func(input_ids):
|
||||
# during training, we add noise to the embedding
|
||||
# during generation, we don't add noise to the embedding
|
||||
if model.training:
|
||||
embed_init = orig_embed(input_ids)
|
||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
return orig_embed(input_ids)
|
||||
|
||||
return new_func
|
||||
|
||||
def post_init(orig_post_init):
|
||||
def new_func(self):
|
||||
orig_post_init(self)
|
||||
self.embed_tokens.forward = noised_embed(
|
||||
self.embed_tokens.forward, noise_alpha, self
|
||||
)
|
||||
|
||||
return new_func
|
||||
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init
|
||||
)
|
||||
65
src/axolotl/monkeypatch/neft_embeddings.py
Normal file
65
src/axolotl/monkeypatch/neft_embeddings.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
def patch_neft(alpha, model):
|
||||
embeddings = None
|
||||
if isinstance(model, PreTrainedModel):
|
||||
embeddings = model.get_input_embeddings()
|
||||
if isinstance(model, PeftModel):
|
||||
embeddings = model.base_model.get_input_embeddings()
|
||||
if not embeddings:
|
||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
||||
embeddings.noisy_embedding_alpha = alpha
|
||||
old_forward = embeddings.forward
|
||||
|
||||
# This hack seems to be needed to properly use a custom forward pass
|
||||
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
||||
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
|
||||
embeddings, embeddings.__class__
|
||||
)
|
||||
setattr(embeddings, "forward", bound_method)
|
||||
|
||||
embeddings._old_forward = old_forward # pylint: disable=protected-access
|
||||
return model
|
||||
|
||||
|
||||
def unpatch_neft(model):
|
||||
embeddings = None
|
||||
if isinstance(model, PreTrainedModel):
|
||||
embeddings = model.get_input_embeddings()
|
||||
if isinstance(model, PeftModel):
|
||||
embeddings = model.base_model.get_input_embeddings()
|
||||
if not embeddings:
|
||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
||||
if hasattr(embeddings, "_old_forward"):
|
||||
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
|
||||
del embeddings._old_forward # pylint: disable=protected-access
|
||||
del embeddings.noisy_embedding_alpha
|
||||
|
||||
|
||||
def neft_forward(self, inputs: torch.Tensor):
|
||||
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
|
||||
|
||||
if self.training:
|
||||
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
|
||||
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
|
||||
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def pretrain_hook(cfg, trainer):
|
||||
if cfg.noisy_embedding_alpha:
|
||||
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
|
||||
|
||||
|
||||
def post_train_hook(cfg, trainer):
|
||||
if cfg.noisy_embedding_alpha:
|
||||
unpatch_neft(trainer.model)
|
||||
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
return SimpleShareGPTPromptTokenizingStrategy(
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg:
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
return strategy
|
||||
|
||||
|
||||
def load_role(tokenizer, cfg):
|
||||
@@ -59,8 +62,26 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
_strict = True
|
||||
|
||||
@property
|
||||
def strict(self):
|
||||
return self._strict
|
||||
|
||||
@strict.setter
|
||||
def strict(self, strict):
|
||||
self._strict = strict
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
conversations = prompt["conversations"]
|
||||
if self.strict:
|
||||
return conversations
|
||||
# remap roles - allow for assistant turn
|
||||
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
|
||||
@@ -245,6 +245,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# pylint: disable=duplicate-code
|
||||
(
|
||||
instruction,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
|
||||
@@ -4,10 +4,12 @@ import logging
|
||||
from enum import Enum
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from colorama import Fore
|
||||
from fastchat.conversation import Conversation, get_conv_template
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
IGNORE_TOKEN_ID = -100
|
||||
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
|
||||
|
||||
|
||||
class PromptStyle(Enum):
|
||||
@@ -55,20 +57,15 @@ class AlpacaPrompter:
|
||||
)
|
||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
def _build_result(self, instruction, input_text, output):
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
if input_text:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_format.format(instruction=instruction, input=input)
|
||||
) + self.turn_format.format(instruction=instruction, input=input_text)
|
||||
else:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_no_input_prompt)
|
||||
@@ -77,7 +74,21 @@ class AlpacaPrompter:
|
||||
) + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
return res
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
yield self._build_result(instruction, input, output)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return REPR_TEMPLATE.format(
|
||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||
)
|
||||
|
||||
|
||||
class UnpromptedPrompter(AlpacaPrompter):
|
||||
@@ -191,14 +202,14 @@ class ReflectAlpacaPrompter:
|
||||
)
|
||||
self.response_split = "ASSISTANT:"
|
||||
|
||||
def build_prompt(
|
||||
def _build_result(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
):
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
@@ -212,7 +223,30 @@ class ReflectAlpacaPrompter:
|
||||
corrected=corrected,
|
||||
)
|
||||
res = f"{res}{label}"
|
||||
yield res
|
||||
|
||||
return res
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# pylint: disable=duplicate-code
|
||||
yield self._build_result(
|
||||
instruction,
|
||||
input,
|
||||
output,
|
||||
reflection,
|
||||
corrected,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return REPR_TEMPLATE.format(
|
||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||
)
|
||||
|
||||
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
@@ -247,7 +281,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
def _build_result(self, source):
|
||||
if len(source) < 2:
|
||||
# If there isn't a back and forth conversation, ignore it
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
@@ -282,11 +316,20 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
for part in conv.get_turns():
|
||||
return conv.get_turns()
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
turns = self._build_result(source)
|
||||
|
||||
for part in turns:
|
||||
if part[0] and not part[1]:
|
||||
LOG.warning(f"role with empty message: {part[0]}")
|
||||
yield part
|
||||
|
||||
def __repr__(self) -> str:
|
||||
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
|
||||
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
|
||||
|
||||
|
||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
"""
|
||||
@@ -304,3 +347,15 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
role_key_human=role_key_human,
|
||||
role_key_model=role_key_model,
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedPrompter:
|
||||
"""
|
||||
A dummy class for custom prompters
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "Pre-tokenized or custom dataset types are unsupported for logging"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
@@ -10,12 +9,14 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.monkeypatch import neft_embeddings
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
@@ -25,7 +26,7 @@ src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.train")
|
||||
LOG = get_logger("axolotl.train")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -43,7 +44,10 @@ def train(
|
||||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||
):
|
||||
# load the tokenizer first
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
LOG.debug(
|
||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||
main_process_only=True,
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
@@ -51,7 +55,10 @@ def train(
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
# Load the model and tokenizer
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
msg = "loading model"
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
@@ -107,6 +114,7 @@ def train(
|
||||
if cfg.group_by_length:
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
|
||||
pretrain_hooks(cfg, trainer)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
@@ -114,6 +122,7 @@ def train(
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
post_train_hooks(cfg, trainer)
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
@@ -163,3 +172,23 @@ def train(
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def pretrain_hooks(cfg, trainer):
|
||||
"""
|
||||
Run hooks right before kicking off the training
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
neft_embeddings.pretrain_hook(cfg, trainer)
|
||||
|
||||
|
||||
def post_train_hooks(cfg, trainer):
|
||||
"""
|
||||
Run hooks right after training completes
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
neft_embeddings.post_train_hook(cfg, trainer)
|
||||
|
||||
@@ -119,3 +119,30 @@ class DataCollatorForSeq2Seq:
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
Collator for multipack specific to the using the BatchSampler
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for item in features
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
@@ -70,7 +70,9 @@ def normalize_config(cfg):
|
||||
else:
|
||||
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
||||
|
||||
if cfg.bf16 or cfg.bfloat16:
|
||||
if cfg.fp8:
|
||||
cfg.torch_dtype = torch.bfloat16
|
||||
elif cfg.bf16 or cfg.bfloat16:
|
||||
cfg.torch_dtype = torch.bfloat16
|
||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||
cfg.torch_dtype = torch.float16
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools
|
||||
import hashlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from datasets import (
|
||||
@@ -36,6 +36,7 @@ from axolotl.prompters import (
|
||||
MultipleChoiceExplainPrompter,
|
||||
ReflectAlpacaPrompter,
|
||||
SummarizeTLDRPrompter,
|
||||
UnsupportedPrompter,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
@@ -55,9 +56,10 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
|
||||
|
||||
def prepare_dataset(cfg, tokenizer):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
with zero_first(is_main_process()):
|
||||
train_dataset, eval_dataset = load_prepare_datasets(
|
||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
else:
|
||||
@@ -70,7 +72,7 @@ def prepare_dataset(cfg, tokenizer):
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
eval_dataset = None
|
||||
return train_dataset, eval_dataset, cfg.max_steps
|
||||
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
@@ -78,12 +80,12 @@ def prepare_dataset(cfg, tokenizer):
|
||||
)
|
||||
if cfg.max_steps:
|
||||
total_num_steps = min(
|
||||
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||
)
|
||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||
else:
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||
return train_dataset, eval_dataset, total_num_steps
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||
|
||||
|
||||
def load_tokenized_prepared_datasets(
|
||||
@@ -109,6 +111,7 @@ def load_tokenized_prepared_datasets(
|
||||
else Path(default_dataset_prepared_path) / ds_hash
|
||||
)
|
||||
dataset = None
|
||||
prompters = []
|
||||
use_auth_token = cfg.hf_use_auth_token
|
||||
try:
|
||||
if cfg.push_dataset_to_hub:
|
||||
@@ -147,13 +150,13 @@ def load_tokenized_prepared_datasets(
|
||||
yield dataset
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
for d in for_d_in_datasets(cfg.datasets):
|
||||
for config_dataset in for_d_in_datasets(cfg.datasets):
|
||||
ds: Union[Dataset, DatasetDict] = None
|
||||
ds_from_hub = False
|
||||
try:
|
||||
load_dataset(
|
||||
d.path,
|
||||
name=d.name,
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=True,
|
||||
token=use_auth_token,
|
||||
)
|
||||
@@ -162,33 +165,33 @@ def load_tokenized_prepared_datasets(
|
||||
pass
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
local_path = Path(d.path)
|
||||
local_path = Path(config_dataset.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
name=d.name,
|
||||
data_files=d.data_files,
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = "json"
|
||||
if d.ds_type:
|
||||
ds_type = d.ds_type
|
||||
elif ".parquet" in d.path:
|
||||
if config_dataset.ds_type:
|
||||
ds_type = config_dataset.ds_type
|
||||
elif ".parquet" in config_dataset.path:
|
||||
ds_type = "parquet"
|
||||
elif ".arrow" in d.path:
|
||||
elif ".arrow" in config_dataset.path:
|
||||
ds_type = "arrow"
|
||||
elif ".csv" in d.path:
|
||||
elif ".csv" in config_dataset.path:
|
||||
ds_type = "csv"
|
||||
elif ".txt" in d.path:
|
||||
elif ".txt" in config_dataset.path:
|
||||
ds_type = "text"
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=d.name,
|
||||
data_files=d.path,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
@@ -198,25 +201,25 @@ def load_tokenized_prepared_datasets(
|
||||
)
|
||||
elif ds_from_hub:
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
name=d.name,
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
data_files=d.data_files,
|
||||
data_files=config_dataset.data_files,
|
||||
token=use_auth_token,
|
||||
)
|
||||
else:
|
||||
if isinstance(d.data_files, str):
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
repo_id=d.path,
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=d.data_files,
|
||||
filename=config_dataset.data_files,
|
||||
)
|
||||
elif isinstance(d.data_files, list):
|
||||
elif isinstance(config_dataset.data_files, list):
|
||||
fp = []
|
||||
for file in d.data_files:
|
||||
for file in config_dataset.data_files:
|
||||
fp.append(
|
||||
hf_hub_download(
|
||||
repo_id=d.path,
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=file,
|
||||
)
|
||||
@@ -226,21 +229,27 @@ def load_tokenized_prepared_datasets(
|
||||
"data_files must be either a string or list of strings"
|
||||
)
|
||||
ds = load_dataset(
|
||||
"json", name=d.name, data_files=fp, streaming=False, split=None
|
||||
"json",
|
||||
name=config_dataset.name,
|
||||
data_files=fp,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
if not ds:
|
||||
raise ValueError("unhandled dataset load")
|
||||
# support for using a subset of the data
|
||||
if d.shards:
|
||||
if config_dataset.shards:
|
||||
if "train" in ds:
|
||||
ds = ds.shuffle(seed=seed)["train"].shard(
|
||||
num_shards=d.shards, index=0
|
||||
num_shards=config_dataset.shards, index=0
|
||||
)
|
||||
else:
|
||||
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
||||
ds = ds.shuffle(seed=seed).shard(
|
||||
num_shards=config_dataset.shards, index=0
|
||||
)
|
||||
|
||||
d_base_type = d_prompt_style = None
|
||||
d_type = d.type
|
||||
d_type = config_dataset.type
|
||||
if isinstance(d_type, str):
|
||||
d_type_split = d_type.split(":")
|
||||
d_base_type = d_type_split[0]
|
||||
@@ -249,108 +258,26 @@ def load_tokenized_prepared_datasets(
|
||||
ds = ds["train"]
|
||||
elif (
|
||||
isinstance(ds, DatasetDict)
|
||||
and d.train_on_split
|
||||
and d.train_on_split in ds
|
||||
and config_dataset.train_on_split
|
||||
and config_dataset.train_on_split in ds
|
||||
):
|
||||
ds = ds[d.train_on_split]
|
||||
ds = ds[config_dataset.train_on_split]
|
||||
elif isinstance(ds, DatasetDict):
|
||||
raise ValueError(
|
||||
f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
|
||||
)
|
||||
if (
|
||||
"input_ids" in ds.features
|
||||
and "attention_mask" in ds.features
|
||||
and "labels" in ds.features
|
||||
):
|
||||
# dataset is already tokenized, just drop it straight in
|
||||
datasets.append(ds)
|
||||
elif isinstance(d.type, DictDefault):
|
||||
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif ds_strategy := load(d.type, tokenizer, cfg, d):
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "alpaca":
|
||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "explainchoice":
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
MultipleChoiceExplainPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "concisechoice":
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
MultipleChoiceConcisePrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "summarizetldr":
|
||||
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
||||
SummarizeTLDRPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "jeopardy":
|
||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
||||
JeopardyPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "oasst":
|
||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||
AlpacaPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "gpteacher":
|
||||
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
||||
GPTeacherPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "reflection":
|
||||
ds_strategy = AlpacaReflectionPTStrategy(
|
||||
ReflectAlpacaPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
else:
|
||||
suffix = ""
|
||||
if ":load_" in d.type:
|
||||
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
||||
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
|
||||
raise ValueError(
|
||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
|
||||
)
|
||||
|
||||
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||
config_dataset=config_dataset,
|
||||
dataset=ds,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
d_base_type=d_base_type,
|
||||
d_prompt_style=d_prompt_style,
|
||||
)
|
||||
datasets.append(dataset_wrapper)
|
||||
prompters.append(dataset_prompter)
|
||||
|
||||
LOG.info("merging datasets")
|
||||
dataset = concatenate_datasets(datasets)
|
||||
|
||||
@@ -368,14 +295,14 @@ def load_tokenized_prepared_datasets(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||
)
|
||||
|
||||
return dataset
|
||||
return dataset, prompters
|
||||
|
||||
|
||||
def load_prepare_datasets(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
cfg,
|
||||
default_dataset_prepared_path,
|
||||
) -> Tuple[Dataset, Dataset]:
|
||||
) -> Tuple[Dataset, Dataset, List[Any]]:
|
||||
max_packed_sequence_len = (
|
||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||
)
|
||||
@@ -384,6 +311,7 @@ def load_prepare_datasets(
|
||||
) # make sure we don't accidentally set it larger than sequence_len
|
||||
|
||||
tokenizer_name = tokenizer.__class__.__name__
|
||||
prompters = []
|
||||
if cfg.max_packed_sequence_len is not None:
|
||||
# see if we can go ahead and load the stacked dataset
|
||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||
@@ -439,7 +367,7 @@ def load_prepare_datasets(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||
)
|
||||
else:
|
||||
dataset = load_tokenized_prepared_datasets(
|
||||
dataset, prompters = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, default_dataset_prepared_path
|
||||
)
|
||||
|
||||
@@ -481,7 +409,7 @@ def load_prepare_datasets(
|
||||
private=True,
|
||||
)
|
||||
else:
|
||||
dataset = load_tokenized_prepared_datasets(
|
||||
dataset, prompters = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, default_dataset_prepared_path
|
||||
)
|
||||
|
||||
@@ -532,7 +460,144 @@ def load_prepare_datasets(
|
||||
train_dataset = dataset
|
||||
eval_dataset = None
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
return train_dataset, eval_dataset, prompters
|
||||
|
||||
|
||||
def get_dataset_wrapper(
|
||||
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
|
||||
):
|
||||
dataset_wrapper = None
|
||||
dataset_prompter = None
|
||||
|
||||
if (
|
||||
"input_ids" in dataset.features
|
||||
and "attention_mask" in dataset.features
|
||||
and "labels" in dataset.features
|
||||
):
|
||||
# dataset is already tokenized, just drop it straight in
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = dataset
|
||||
elif isinstance(config_dataset.type, DictDefault):
|
||||
ds_strategy = load(
|
||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||
)
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
elif d_base_type == "alpaca":
|
||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "explainchoice":
|
||||
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "concisechoice":
|
||||
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "summarizetldr":
|
||||
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
||||
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "jeopardy":
|
||||
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "oasst":
|
||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "gpteacher":
|
||||
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
||||
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "reflection":
|
||||
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaReflectionPTStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
else:
|
||||
suffix = ""
|
||||
if ":load_" in config_dataset.type:
|
||||
suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
|
||||
LOG.error(
|
||||
f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
|
||||
)
|
||||
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def encode_pretraining(
|
||||
|
||||
@@ -3,6 +3,9 @@ import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import numba
|
||||
@@ -149,6 +152,8 @@ class MultipackDistributedDataloader:
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
sample_packing_seq_len_multiplier: int = 1,
|
||||
device_count: int = 1,
|
||||
prefetch_max: int = 1000,
|
||||
num_epochs: int = 1,
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
@@ -167,6 +172,7 @@ class MultipackDistributedDataloader:
|
||||
self.seq_max_length = seq_max_length
|
||||
self.batch_max_length = batch_size * seq_max_length
|
||||
self.collate_fn = collate_fn
|
||||
self.num_epochs = num_epochs
|
||||
|
||||
self.num_replicas = 1
|
||||
self.rank = 0
|
||||
@@ -177,6 +183,44 @@ class MultipackDistributedDataloader:
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.device_count = device_count
|
||||
|
||||
# maxsize is maximum number of samples in queue
|
||||
self.prefetch_max = prefetch_max
|
||||
self.queue: Queue = Queue(maxsize=prefetch_max)
|
||||
self.thread = None
|
||||
|
||||
def _worker(self):
|
||||
LOG.info(
|
||||
f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
|
||||
)
|
||||
for epoch in range(self.num_epochs):
|
||||
for sample in self._internal_batch_generator():
|
||||
while True:
|
||||
if self.queue.full():
|
||||
time.sleep(1)
|
||||
else:
|
||||
break
|
||||
self.queue.put(sample)
|
||||
|
||||
# stop the queue when epoch is done
|
||||
self.queue.put(None)
|
||||
|
||||
def __iter__(self):
|
||||
if hasattr(self.sampler, "set_epoch"):
|
||||
new_epoch = self.sampler.epoch + 1
|
||||
self.sampler.set_epoch(new_epoch)
|
||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||
|
||||
if self.thread is None:
|
||||
self.thread = Thread(target=self._worker, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
while True:
|
||||
item = self.queue.get()
|
||||
|
||||
if item is None:
|
||||
break
|
||||
yield item
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
LOG.info("generating packed batches")
|
||||
if self.sampler:
|
||||
@@ -206,11 +250,7 @@ class MultipackDistributedDataloader:
|
||||
|
||||
return batches, totseqs
|
||||
|
||||
def __iter__(self):
|
||||
if hasattr(self.sampler, "set_epoch"):
|
||||
new_epoch = self.sampler.epoch + 1
|
||||
self.sampler.set_epoch(new_epoch)
|
||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||
def _internal_batch_generator(self):
|
||||
all_batches, _ = self.generate_batches(set_stats=True)
|
||||
features = self.dataset.features.keys()
|
||||
len_remaining = self._len_est()
|
||||
|
||||
@@ -50,6 +50,17 @@ def get_world_size():
|
||||
return int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_only():
|
||||
"""
|
||||
Context manager that only runs the enclosed block on the main rank.
|
||||
"""
|
||||
if is_main_process():
|
||||
yield
|
||||
else:
|
||||
yield None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_first(is_main):
|
||||
"""
|
||||
|
||||
@@ -31,7 +31,7 @@ LOG = logging.getLogger("axolotl")
|
||||
|
||||
def load_model_config(cfg):
|
||||
model_config_name = cfg.base_model_config or cfg.base_model
|
||||
trust_remote_code: bool = False or cfg.trust_remote_code
|
||||
trust_remote_code = cfg.trust_remote_code is True
|
||||
return AutoConfig.from_pretrained(
|
||||
model_config_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
@@ -72,11 +72,6 @@ def load_tokenizer(cfg):
|
||||
# set a pad_token, but use eos_token so we don't add a new token
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -98,6 +93,11 @@ def load_tokenizer(cfg):
|
||||
]
|
||||
)
|
||||
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@@ -180,26 +180,6 @@ def load_model(
|
||||
LOG.info("patching with flash attention")
|
||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
|
||||
from axolotl.monkeypatch.llama_embeddings_hijack import (
|
||||
replace_llama_embeddings_with_uniform_distribution,
|
||||
)
|
||||
|
||||
LOG.info("patching with noisy embeddings")
|
||||
replace_llama_embeddings_with_uniform_distribution(
|
||||
noise_alpha=cfg.noisy_embedding_alpha
|
||||
)
|
||||
|
||||
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
|
||||
from axolotl.monkeypatch.mistral_embeddings_hijack import (
|
||||
replace_mistral_embeddings_with_uniform_distribution,
|
||||
)
|
||||
|
||||
LOG.info("patching with noisy embeddings")
|
||||
replace_mistral_embeddings_with_uniform_distribution(
|
||||
noise_alpha=cfg.noisy_embedding_alpha
|
||||
)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||
replace_llama_rope_with_xpos_rope,
|
||||
@@ -406,6 +386,20 @@ def load_model(
|
||||
)
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if (
|
||||
hasattr(model.config, "bos_token_id")
|
||||
and model.config.bos_token_id
|
||||
and model.config.bos_token_id != tokenizer.bos_token_id
|
||||
):
|
||||
model.config.bos_token_id = tokenizer.bos_token_id
|
||||
|
||||
if (
|
||||
hasattr(model.config, "eos_token_id")
|
||||
and model.config.eos_token_id
|
||||
and model.config.eos_token_id != tokenizer.eos_token_id
|
||||
):
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
@@ -448,14 +442,7 @@ def load_model(
|
||||
if cfg.ddp and not load_in_8bit:
|
||||
model.to(f"cuda:{cfg.local_rank}")
|
||||
|
||||
if (
|
||||
torch.cuda.device_count() > 1
|
||||
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
||||
and (cfg.load_in_4bit)
|
||||
):
|
||||
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
||||
# so let's only set it for the 4bit, see
|
||||
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
||||
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
setattr(model, "is_parallelizable", True)
|
||||
setattr(model, "model_parallel", True)
|
||||
|
||||
|
||||
4
src/axolotl/utils/samplers/__init__.py
Normal file
4
src/axolotl/utils/samplers/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
axolotl samplers module
|
||||
"""
|
||||
from .multipack import MultipackBatchSampler # noqa: F401
|
||||
193
src/axolotl/utils/samplers/multipack.py
Normal file
193
src/axolotl/utils/samplers/multipack.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# pylint: skip-file
|
||||
"""
|
||||
Multipack Batch Sampler
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Iterable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler, Sampler
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
# First-fit-decreasing bin packing
|
||||
# Check if a[] could fit in n bins with capacity c
|
||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||
|
||||
a = np.sort(a)[::-1]
|
||||
bins = np.full((n,), c, dtype=a.dtype)
|
||||
for size in a:
|
||||
not_found = True
|
||||
for idx in range(n):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
not_found = False
|
||||
break
|
||||
|
||||
if not_found:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||
# First-fit-decreasing bin packing (with result return)
|
||||
|
||||
indices = np.argsort(a)[::-1]
|
||||
a = a[indices]
|
||||
|
||||
bins: List[Any] = []
|
||||
bins_result: List[Any] = []
|
||||
for a_id, size in enumerate(a):
|
||||
add_new = True
|
||||
for idx in range(len(bins)):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
bins_result[idx].append(indices[a_id] + start_index)
|
||||
add_new = False
|
||||
break
|
||||
|
||||
if add_new:
|
||||
bins.append(c - size)
|
||||
bins_result.append([indices[a_id] + start_index])
|
||||
|
||||
return bins_result
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate(
|
||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||
):
|
||||
# Dynamic batch allocator, similar to Multifit
|
||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||
|
||||
s = 0
|
||||
start_index = 0
|
||||
result = []
|
||||
|
||||
while True:
|
||||
# binary search [l, r)
|
||||
left = 1
|
||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||
|
||||
while right - left > 1:
|
||||
mid = (left + right) // 2
|
||||
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
||||
left = mid
|
||||
else:
|
||||
right = mid
|
||||
|
||||
# use length l
|
||||
batch = ffd_with_result(
|
||||
lengths[start_index : start_index + left], c, start_index
|
||||
)
|
||||
assert len(batch) <= n
|
||||
if len(batch) < n:
|
||||
break
|
||||
|
||||
start_index += left
|
||||
s = lengths_cumsum[start_index - 1]
|
||||
|
||||
# add local rank
|
||||
result.append(batch[rank])
|
||||
|
||||
return result, s, len(result) * c * n
|
||||
|
||||
|
||||
class MultipackBatchSampler(BatchSampler):
|
||||
"""
|
||||
Batch Sampler class for multipack
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampler: Union[Sampler[int], Iterable[int]],
|
||||
batch_size: int,
|
||||
drop_last: bool,
|
||||
batch_max_len: int,
|
||||
lengths: np.ndarray,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.batch_size = None
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
self.epoch = 0
|
||||
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
self.epoch = epoch
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
indices = [idx for idx in self.sampler]
|
||||
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used += total_used
|
||||
self.eff_total_slots += total_slots
|
||||
|
||||
return batches
|
||||
|
||||
def __iter__(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
return iter(batches)
|
||||
|
||||
def num_batches(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
return len(batches)
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
|
||||
def __len__(self):
|
||||
self.num_batches()
|
||||
return self._len_est()
|
||||
|
||||
def _len_est(self):
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // world_size
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
)
|
||||
|
||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||
return (
|
||||
world_size
|
||||
* math.floor(
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
// self.batch_max_len
|
||||
)
|
||||
- 1
|
||||
)
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
@@ -9,21 +8,15 @@ from typing import List
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import set_caching_enabled
|
||||
from torch.utils.data import DistributedSampler, RandomSampler
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder
|
||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||
from axolotl.utils.distributed import (
|
||||
is_distributed,
|
||||
is_main_process,
|
||||
reduce_and_broadcast,
|
||||
zero_first,
|
||||
)
|
||||
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
||||
from axolotl.utils.samplers import MultipackBatchSampler
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
LOG = get_logger("axolotl")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -148,19 +141,18 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
def calculate_total_num_steps(cfg, train_dataset):
|
||||
if cfg.sample_packing:
|
||||
# we have to drop anything longer then sequence len otherwise
|
||||
# flash attention with position ids fails
|
||||
if not cfg.total_num_tokens:
|
||||
LOG.info("calculating total_num_tokens")
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.data.column("input_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||
.values
|
||||
)
|
||||
LOG.info(f"total_num_tokens: {total_num_tokens}")
|
||||
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
|
||||
cfg.total_num_tokens = total_num_tokens
|
||||
|
||||
if not cfg.total_supervised_tokens:
|
||||
@@ -170,7 +162,10 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
.apply(lambda x: np.sum(np.array(x) != -100))
|
||||
.sum()
|
||||
)
|
||||
LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
|
||||
LOG.debug(
|
||||
f"`total_supervised_tokens: {total_supervised_tokens}`",
|
||||
main_process_only=True,
|
||||
)
|
||||
cfg.total_supervised_tokens = total_supervised_tokens
|
||||
|
||||
if cfg.sample_packing_eff_est:
|
||||
@@ -189,40 +184,41 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
)
|
||||
* cfg.num_epochs
|
||||
)
|
||||
LOG.info(
|
||||
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
|
||||
LOG.debug(
|
||||
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
|
||||
main_process_only=True,
|
||||
)
|
||||
else:
|
||||
if cfg.world_size > 1 and is_distributed():
|
||||
sampler = DistributedSampler(
|
||||
train_dataset,
|
||||
num_replicas=cfg.world_size,
|
||||
rank=dist.get_rank(),
|
||||
seed=cfg.seed or 42,
|
||||
)
|
||||
else:
|
||||
sampler = RandomSampler(train_dataset)
|
||||
|
||||
data_loader = MultipackDistributedDataloader(
|
||||
train_dataset,
|
||||
sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
||||
collate_fn=DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
drop_last=True,
|
||||
batch_max_len=cfg.micro_batch_size
|
||||
* (cfg.max_packed_sequence_len or cfg.sequence_len),
|
||||
lengths=(
|
||||
train_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
sampler=sampler,
|
||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
data_loader_len = data_loader.len_w_stats()
|
||||
actual_eff = data_loader.efficiency()
|
||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
|
||||
data_loader = DataLoader(
|
||||
train_dataset.remove_columns(["length"]),
|
||||
batch_sampler=sampler,
|
||||
)
|
||||
data_loader_len = len(data_loader)
|
||||
actual_eff = sampler.efficiency()
|
||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
# on the agreed on value for sample_packing_eff_est
|
||||
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len
|
||||
* cfg.num_epochs
|
||||
/ int(os.environ.get("WORLD_SIZE", 1))
|
||||
)
|
||||
)
|
||||
|
||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||
@@ -236,12 +232,20 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
||||
)
|
||||
cfg.sample_packing_eff_est = sample_packing_eff_est
|
||||
LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
|
||||
LOG.debug(
|
||||
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
|
||||
main_process_only=True,
|
||||
)
|
||||
else:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
/ int(os.environ.get("WORLD_SIZE", 1))
|
||||
/ cfg.batch_size
|
||||
)
|
||||
)
|
||||
LOG.info(f"total_num_steps: {total_num_steps}")
|
||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
||||
return total_num_steps
|
||||
|
||||
|
||||
@@ -264,6 +268,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
setup_fsdp_envs(cfg)
|
||||
elif cfg.deepspeed:
|
||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
if cfg.fp8:
|
||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
|
||||
|
||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
||||
trainer_builder.train_dataset = train_dataset
|
||||
|
||||
0
tests/e2e/__init__.py
Normal file
0
tests/e2e/__init__.py
Normal file
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -16,6 +15,8 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -25,9 +26,9 @@ class TestFusedLlama(unittest.TestCase):
|
||||
Test case for Llama models using Fused layers
|
||||
"""
|
||||
|
||||
def test_fft_packing(self):
|
||||
@with_temp_dir
|
||||
def test_fft_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
@@ -51,7 +52,7 @@ class TestFusedLlama(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -69,4 +70,4 @@ class TestFusedLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -14,6 +13,8 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -23,9 +24,9 @@ class TestLoraLlama(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
def test_lora(self):
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
@@ -52,7 +53,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -63,11 +64,11 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_lora_packing(self):
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
@@ -96,7 +97,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -107,11 +108,11 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_lora_gptq(self):
|
||||
@with_temp_dir
|
||||
def test_lora_gptq(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
||||
@@ -144,7 +145,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"save_steps": 0.5,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -155,4 +156,4 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -16,6 +15,8 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -25,9 +26,9 @@ class TestMistral(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
def test_lora(self):
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
@@ -54,7 +55,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -68,11 +69,11 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_ft(self):
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
@@ -93,7 +94,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -111,4 +112,4 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -16,6 +15,8 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -25,9 +26,9 @@ class TestMistral(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
def test_lora_packing(self):
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
@@ -55,7 +56,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -69,11 +70,11 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_ft_packing(self):
|
||||
@with_temp_dir
|
||||
def test_ft_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
@@ -95,7 +96,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -113,4 +114,4 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@@ -4,8 +4,8 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,6 +13,8 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -22,7 +24,8 @@ class TestPhi(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
def test_ft(self):
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -52,7 +55,7 @@ class TestPhi(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": tempfile.mkdtemp(),
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -64,8 +67,10 @@ class TestPhi(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
def test_ft_packed(self):
|
||||
@with_temp_dir
|
||||
def test_ft_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -95,7 +100,7 @@ class TestPhi(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": tempfile.mkdtemp(),
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -107,3 +112,4 @@ class TestPhi(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
22
tests/e2e/utils.py
Normal file
22
tests/e2e/utils.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
helper utils for tests
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from functools import wraps
|
||||
|
||||
|
||||
def with_temp_dir(test_func):
|
||||
@wraps(test_func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Create a temporary directory
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
# Pass the temporary directory to the test function
|
||||
test_func(*args, temp_dir=temp_dir, **kwargs)
|
||||
finally:
|
||||
# Clean up the directory after the test
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
return wrapper
|
||||
Reference in New Issue
Block a user