Compare commits

..

8 Commits

Author SHA1 Message Date
sunny
bfb80a3ef9 stuff 2024-10-30 13:44:06 -04:00
sunny
38773d661f fixing 2024-10-30 11:04:50 -04:00
sunny
271c2c2b82 fixed formatting 2024-10-29 15:50:56 -04:00
sunny
32b6f30947 fix attempt at issue 1991 2024-10-29 15:44:32 -04:00
sunny
fc1f275e6c yml change 2024-10-29 15:27:42 -04:00
sunny
46d2b4ce89 yml change 2024-10-29 15:25:25 -04:00
sunny
88c9a7aecc LOG for debug 2024-10-29 13:35:55 -04:00
sunny
d9a93990d1 yml 2024-10-29 10:40:32 -04:00
34 changed files with 572 additions and 760 deletions

View File

@@ -40,7 +40,7 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.5.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout

View File

@@ -82,6 +82,13 @@ jobs:
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -72,52 +72,12 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest, docker-e2e-tests-1st]
needs: [pre-commit, pytest]
strategy:
fail-fast: false
@@ -129,6 +89,18 @@ jobs:
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

295
1991.yml Normal file
View File

@@ -0,0 +1,295 @@
base_model: Qwen/Qwen2.5-14B-Instruct
model_type: AutoModelForCausalLM #nohup accelerate launch -m axolotl.cli.train /home/ubuntu/qwen2.5_14B.yml > training_output.log 2>&1 &
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
chat_template: chatml
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
# input_layernorm layers
- model.layers.0.input_layernorm
- model.layers.1.input_layernorm
- model.layers.2.input_layernorm
- model.layers.3.input_layernorm
- model.layers.4.input_layernorm
- model.layers.5.input_layernorm
- model.layers.6.input_layernorm
- model.layers.7.input_layernorm
- model.layers.8.input_layernorm
- model.layers.9.input_layernorm
- model.layers.10.input_layernorm
- model.layers.11.input_layernorm
- model.layers.12.input_layernorm
- model.layers.13.input_layernorm
- model.layers.14.input_layernorm
- model.layers.15.input_layernorm
- model.layers.16.input_layernorm
- model.layers.17.input_layernorm
- model.layers.18.input_layernorm
- model.layers.19.input_layernorm
- model.layers.20.input_layernorm
- model.layers.21.input_layernorm
- model.layers.22.input_layernorm
- model.layers.23.input_layernorm
# lm_head layers
# mlp.down_proj layers
- model.layers.1.mlp.down_proj
- model.layers.35.mlp.down_proj
- model.layers.38.mlp.down_proj
- model.layers.37.mlp.down_proj
- model.layers.36.mlp.down_proj
- model.layers.15.mlp.down_proj
- model.layers.11.mlp.down_proj
- model.layers.12.mlp.down_proj
- model.layers.34.mlp.down_proj
- model.layers.44.mlp.down_proj
- model.layers.45.mlp.down_proj
- model.layers.9.mlp.down_proj
- model.layers.41.mlp.down_proj
- model.layers.33.mlp.down_proj
- model.layers.43.mlp.down_proj
- model.layers.40.mlp.down_proj
- model.layers.13.mlp.down_proj
- model.layers.8.mlp.down_proj
- model.layers.39.mlp.down_proj
- model.layers.10.mlp.down_proj
- model.layers.14.mlp.down_proj
- model.layers.16.mlp.down_proj
- model.layers.31.mlp.down_proj
- model.layers.32.mlp.down_proj
# mlp.gate_proj layers
- model.layers.1.mlp.gate_proj
- model.layers.44.mlp.gate_proj
- model.layers.46.mlp.gate_proj
- model.layers.45.mlp.gate_proj
- model.layers.43.mlp.gate_proj
- model.layers.47.mlp.gate_proj
- model.layers.42.mlp.gate_proj
- model.layers.32.mlp.gate_proj
- model.layers.27.mlp.gate_proj
- model.layers.33.mlp.gate_proj
- model.layers.28.mlp.gate_proj
- model.layers.39.mlp.gate_proj
- model.layers.41.mlp.gate_proj
- model.layers.40.mlp.gate_proj
- model.layers.30.mlp.gate_proj
- model.layers.29.mlp.gate_proj
- model.layers.31.mlp.gate_proj
- model.layers.26.mlp.gate_proj
- model.layers.37.mlp.gate_proj
- model.layers.10.mlp.gate_proj
- model.layers.38.mlp.gate_proj
- model.layers.12.mlp.gate_proj
- model.layers.36.mlp.gate_proj
- model.layers.13.mlp.gate_proj
# mlp.up_proj layers
- model.layers.1.mlp.up_proj
- model.layers.13.mlp.up_proj
- model.layers.11.mlp.up_proj
- model.layers.14.mlp.up_proj
- model.layers.15.mlp.up_proj
- model.layers.12.mlp.up_proj
- model.layers.8.mlp.up_proj
- model.layers.16.mlp.up_proj
- model.layers.9.mlp.up_proj
- model.layers.19.mlp.up_proj
- model.layers.10.mlp.up_proj
- model.layers.7.mlp.up_proj
- model.layers.17.mlp.up_proj
- model.layers.20.mlp.up_proj
- model.layers.21.mlp.up_proj
- model.layers.18.mlp.up_proj
- model.layers.38.mlp.up_proj
- model.layers.37.mlp.up_proj
- model.layers.39.mlp.up_proj
- model.layers.42.mlp.up_proj
- model.layers.41.mlp.up_proj
- model.layers.27.mlp.up_proj
- model.layers.28.mlp.up_proj
- model.layers.34.mlp.up_proj
# model.norm layers
# post_attention_layernorm layers
- model.layers.0.post_attention_layernorm
- model.layers.1.post_attention_layernorm
- model.layers.2.post_attention_layernorm
- model.layers.3.post_attention_layernorm
- model.layers.4.post_attention_layernorm
- model.layers.5.post_attention_layernorm
- model.layers.6.post_attention_layernorm
- model.layers.7.post_attention_layernorm
- model.layers.8.post_attention_layernorm
- model.layers.9.post_attention_layernorm
- model.layers.10.post_attention_layernorm
- model.layers.11.post_attention_layernorm
- model.layers.12.post_attention_layernorm
- model.layers.13.post_attention_layernorm
- model.layers.14.post_attention_layernorm
- model.layers.15.post_attention_layernorm
- model.layers.16.post_attention_layernorm
- model.layers.17.post_attention_layernorm
- model.layers.18.post_attention_layernorm
- model.layers.19.post_attention_layernorm
- model.layers.20.post_attention_layernorm
- model.layers.21.post_attention_layernorm
- model.layers.22.post_attention_layernorm
- model.layers.23.post_attention_layernorm
# self_attn.k_proj layers
- model.layers.47.self_attn.k_proj
- model.layers.39.self_attn.k_proj
- model.layers.41.self_attn.k_proj
- model.layers.37.self_attn.k_proj
- model.layers.35.self_attn.k_proj
- model.layers.44.self_attn.k_proj
- model.layers.38.self_attn.k_proj
- model.layers.14.self_attn.k_proj
- model.layers.7.self_attn.k_proj
- model.layers.12.self_attn.k_proj
- model.layers.11.self_attn.k_proj
- model.layers.32.self_attn.k_proj
- model.layers.10.self_attn.k_proj
- model.layers.8.self_attn.k_proj
- model.layers.9.self_attn.k_proj
- model.layers.6.self_attn.k_proj
- model.layers.45.self_attn.k_proj
- model.layers.42.self_attn.k_proj
- model.layers.5.self_attn.k_proj
- model.layers.40.self_attn.k_proj
- model.layers.33.self_attn.k_proj
- model.layers.0.self_attn.k_proj
- model.layers.34.self_attn.k_proj
- model.layers.13.self_attn.k_proj
# self_attn.o_proj layers
- model.layers.12.self_attn.o_proj
- model.layers.5.self_attn.o_proj
- model.layers.14.self_attn.o_proj
- model.layers.16.self_attn.o_proj
- model.layers.20.self_attn.o_proj
- model.layers.13.self_attn.o_proj
- model.layers.11.self_attn.o_proj
- model.layers.4.self_attn.o_proj
- model.layers.6.self_attn.o_proj
- model.layers.19.self_attn.o_proj
- model.layers.7.self_attn.o_proj
- model.layers.18.self_attn.o_proj
- model.layers.8.self_attn.o_proj
- model.layers.38.self_attn.o_proj
- model.layers.15.self_attn.o_proj
- model.layers.17.self_attn.o_proj
- model.layers.9.self_attn.o_proj
- model.layers.10.self_attn.o_proj
- model.layers.21.self_attn.o_proj
- model.layers.28.self_attn.o_proj
- model.layers.32.self_attn.o_proj
- model.layers.35.self_attn.o_proj
- model.layers.39.self_attn.o_proj
- model.layers.3.self_attn.o_proj
# self_attn.q_proj layers
- model.layers.1.self_attn.q_proj
- model.layers.2.self_attn.q_proj
- model.layers.3.self_attn.q_proj
- model.layers.44.self_attn.q_proj
- model.layers.29.self_attn.q_proj
- model.layers.45.self_attn.q_proj
- model.layers.43.self_attn.q_proj
- model.layers.32.self_attn.q_proj
- model.layers.38.self_attn.q_proj
- model.layers.19.self_attn.q_proj
- model.layers.42.self_attn.q_proj
- model.layers.34.self_attn.q_proj
- model.layers.36.self_attn.q_proj
- model.layers.40.self_attn.q_proj
- model.layers.26.self_attn.q_proj
- model.layers.20.self_attn.q_proj
- model.layers.39.self_attn.q_proj
- model.layers.28.self_attn.q_proj
- model.layers.35.self_attn.q_proj
- model.layers.41.self_attn.q_proj
- model.layers.33.self_attn.q_proj
- model.layers.25.self_attn.q_proj
- model.layers.30.self_attn.q_proj
- model.layers.27.self_attn.q_proj
# self_attn.v_proj layers
- model.layers.0.self_attn.v_proj
- model.layers.7.self_attn.v_proj
- model.layers.39.self_attn.v_proj
- model.layers.31.self_attn.v_proj
- model.layers.15.self_attn.v_proj
- model.layers.10.self_attn.v_proj
- model.layers.32.self_attn.v_proj
- model.layers.41.self_attn.v_proj
- model.layers.6.self_attn.v_proj
- model.layers.33.self_attn.v_proj
- model.layers.42.self_attn.v_proj
- model.layers.29.self_attn.v_proj
- model.layers.14.self_attn.v_proj
- model.layers.9.self_attn.v_proj
- model.layers.35.self_attn.v_proj
- model.layers.38.self_attn.v_proj
- model.layers.13.self_attn.v_proj
- model.layers.30.self_attn.v_proj
- model.layers.5.self_attn.v_proj
- model.layers.34.self_attn.v_proj
- model.layers.28.self_attn.v_proj
- model.layers.37.self_attn.v_proj
- model.layers.27.self_attn.v_proj
- model.layers.11.self_attn.v_proj
# model.embed_tokens layers
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: linear
learning_rate: 5e-6
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
gradient_checkpointing: unsloth
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 2
saves_per_epoch: 1
save_total_limit: 4
debug:
deepspeed: deepspeed_configs/zero3_bf16.json
weight_decay: 0.05
special_tokens:
eos_token: <|im_end|>

View File

@@ -562,8 +562,7 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
```

View File

@@ -7,8 +7,8 @@ load_in_8bit: true
load_in_4bit: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
- path: philschmid/guanaco-sharegpt-style
type: sharegpt
shards: 10
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model

View File

@@ -35,7 +35,3 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTHON_VERSION" != "2.5.1" ] ; then \
pip3 install flash-attn==2.6.3; \
fi

View File

@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
### Background
The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
```yaml
datasets:
- path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
type: chat_template
- path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
type: sharegpt
```
>[!Important]
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
```jsonc
// .vscode/launch.json
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"version": "0.2.0",
"configurations": [
{
"name": "Debug axolotl prompt - chat_template",
"name": "Debug axolotl prompt - sharegpt",
"type": "python",
"module": "accelerate.commands.launch",
"request": "launch",
"args": [
"-m", "axolotl.cli.train", "dev_chat_template.yml",
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
</div>
<br>
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).

View File

@@ -9,17 +9,14 @@ strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rms_norm: true
liger_glu_activation: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
chat_template: deepseek_v2
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -11,11 +11,8 @@ chat_template: gemma
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: gemma
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -4,15 +4,11 @@ tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
use_tensorboard: true
chat_template: jamba
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: jamba
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft

View File

@@ -4,7 +4,7 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
strict: false
@@ -14,10 +14,6 @@ datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./outputs/out

View File

@@ -10,6 +10,7 @@ chat_template: phi_3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: phi_3
field_messages: messages
message_field_role: role
message_field_content: content

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.46.1
transformers==4.46.0
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.1.0
accelerate==1.0.1
datasets==3.0.1
deepspeed==0.15.3
pydantic==2.6.3
@@ -34,7 +34,7 @@ tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
triton>=2.3.0
liger-kernel==0.4.0
liger-kernel==0.3.0
mamba-ssm==1.2.0.post1

View File

@@ -272,7 +272,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -48,7 +48,6 @@ from trl import (
)
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -896,13 +895,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, **kwargs):
def _save_checkpoint(self, model, trial):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
return super()._save_checkpoint(model, trial)
class AxolotlMambaTrainer(AxolotlTrainer):
@@ -1148,12 +1147,6 @@ class TrainerBuilderBase(abc.ABC):
def get_callbacks(self) -> List[TrainerCallback]:
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -1180,17 +1173,11 @@ class TrainerBuilderBase(abc.ABC):
return callbacks
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
@@ -1236,7 +1223,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
@@ -1608,8 +1595,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template,
tokenizer=self.tokenizer,
self.cfg.chat_template
)
if self.cfg.rl == "orpo":
@@ -1804,7 +1790,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
callbacks = []
return callbacks
def build_training_arguments(self, total_num_steps):
@@ -2013,11 +1999,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks = []
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
callbacks = []
return callbacks
def build(self, total_num_steps):

View File

@@ -18,10 +18,9 @@ Plugins can be used to integrate third-party models, modify the training process
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
import collections
import importlib
import logging
from typing import OrderedDict
from typing import List
class BasePlugin:
@@ -48,7 +47,7 @@ class BasePlugin:
Initializes the BasePlugin.
"""
def register(self, cfg): # pylint: disable=unused-argument
def register(self, cfg):
"""
Registers the plugin with the given configuration.
@@ -64,7 +63,7 @@ class BasePlugin:
Returns a pydantic model for the plugin's input arguments.
"""
def pre_model_load(self, cfg): # pylint: disable=unused-argument
def pre_model_load(self, cfg):
"""
Performs actions before the model is loaded.
@@ -75,7 +74,7 @@ class BasePlugin:
None
"""
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
def post_model_load(self, cfg, model):
"""
Performs actions after the model is loaded.
@@ -87,7 +86,7 @@ class BasePlugin:
None
"""
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
def pre_lora_load(self, cfg, model):
"""
Performs actions before LoRA weights are loaded.
@@ -99,7 +98,7 @@ class BasePlugin:
None
"""
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
def post_lora_load(self, cfg, model):
"""
Performs actions after LoRA weights are loaded.
@@ -111,7 +110,7 @@ class BasePlugin:
None
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
def create_optimizer(self, cfg, trainer):
"""
Creates and returns an optimizer for training.
@@ -123,9 +122,7 @@ class BasePlugin:
object: The created optimizer.
"""
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
def create_lr_scheduler(self, cfg, trainer, optimizer):
"""
Creates and returns a learning rate scheduler.
@@ -138,7 +135,7 @@ class BasePlugin:
object: The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
def add_callbacks_pre_trainer(self, cfg, model):
"""
Adds callbacks to the trainer before training.
@@ -149,11 +146,8 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Adds callbacks to the trainer after training.
@@ -164,9 +158,8 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def post_train(self, cfg, model): # pylint: disable=unused-argument
def post_train(self, cfg, model):
"""
Performs actions after training is complete.
@@ -178,7 +171,7 @@ class BasePlugin:
None
"""
def post_train_unload(self, cfg): # pylint: disable=unused-argument
def post_train_unload(self, cfg):
"""
Performs actions after training is complete and the model is unloaded.
@@ -234,7 +227,7 @@ class PluginManager:
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
plugins: List[BasePlugin] = []
_instance = None
@@ -244,7 +237,7 @@ class PluginManager:
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins = collections.OrderedDict()
cls._instance.plugins: List[BasePlugin] = []
return cls._instance
@staticmethod
@@ -272,7 +265,7 @@ class PluginManager:
"""
try:
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
self.plugins.append(plugin)
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
@@ -284,7 +277,7 @@ class PluginManager:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
"""
input_args = []
for plugin in self.plugins.values():
for plugin in self.plugins:
input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin)
@@ -300,7 +293,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.pre_model_load(cfg)
def post_model_load(self, cfg, model):
@@ -314,7 +307,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model):
@@ -328,7 +321,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
@@ -342,7 +335,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.post_lora_load(cfg, model)
def create_optimizer(self, cfg, trainer):
@@ -356,7 +349,7 @@ class PluginManager:
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
@@ -374,7 +367,7 @@ class PluginManager:
Returns:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
@@ -392,7 +385,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins.values():
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
return callbacks
@@ -408,7 +401,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins.values():
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks
@@ -423,5 +416,5 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.post_train_unload(cfg)

View File

@@ -18,23 +18,20 @@ Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import inspect
import logging
import sys
from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.liger")
class LigerPlugin(BasePlugin):
"""
@@ -45,31 +42,59 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
kwargs = {}
if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs[
"fused_linear_cross_entropy"
] = cfg.liger_fused_linear_cross_entropy
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:
kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
with zero_only():
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
if cfg.model_config_type == "llama":
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
from transformers.models.llama import modeling_llama
if cfg.liger_rope:
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_llama.LlamaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_llama.LlamaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
elif cfg.liger_fused_linear_cross_entropy:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
elif cfg.model_config_type == "mistral":
from liger_kernel.transformers.model.mistral import (
lce_forward as mistral_lce_forward,
)
from transformers.models.mistral import modeling_mistral
if cfg.liger_rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
apply_liger_fn(**kwargs)
if cfg.liger_swiglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
@@ -79,12 +104,30 @@ class LigerPlugin(BasePlugin):
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
if cfg.liger_swiglu:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "qwen2":
from liger_kernel.transformers.model.qwen2 import (
lce_forward as qwen2_lce_forward,
)
from transformers.models.qwen2 import modeling_qwen2
if cfg.liger_rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
@@ -103,9 +146,44 @@ class LigerPlugin(BasePlugin):
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
if cfg.liger_swiglu:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy:
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)
elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3
if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

View File

@@ -15,12 +15,9 @@
"""
Module for handling LIGER input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.liger.args")
from pydantic import BaseModel
class LigerArgs(BaseModel):
@@ -30,24 +27,6 @@ class LigerArgs(BaseModel):
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def check_deprecated_swiglu(cls, data):
if data.get("liger_swiglu") is not None:
if data.get("liger_glu_activation") is not None:
raise ValueError(
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
)
LOG.warning(
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
"Please use 'liger_glu_activation' instead."
)
data["liger_glu_activation"] = data.pop("liger_swiglu")
return data

View File

@@ -27,15 +27,18 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
# def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
# elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
elif hasattr(transformers, "modeling_flash_attention_utils"):
if not has_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return

File diff suppressed because one or more lines are too long

View File

@@ -57,7 +57,6 @@ class ChatTemplate(str, Enum):
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):

View File

@@ -2,11 +2,9 @@
import functools
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union
import requests
from datasets import (
Dataset,
DatasetDict,
@@ -55,28 +53,6 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []
if not cfg.pretraining_dataset:

View File

@@ -394,10 +394,15 @@ class ModelLoader:
and self.cfg.flash_attention
and self.cfg.sample_packing
):
has_remote_code = (
"auto_map" in self.model_config
and self.model_type in self.model_config["auto_map"]
)
patch_for_multipack(
self.cfg.model_config_type,
model_name=self.cfg.base_model,
is_remote_code=self.cfg.trust_remote_code,
has_remote_code=has_remote_code,
)
if self.cfg.is_llama_derived_model:
@@ -640,7 +645,9 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
elif self.cfg.adapter == "qlora" and (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
):
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
@@ -663,7 +670,9 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
elif self.cfg.adapter == "lora" and (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
):
bnb_config = {
"load_in_8bit": True,
}
@@ -676,8 +685,10 @@ class ModelLoader:
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None)
self.model_kwargs.pop("load_in_4bit", None)
if "load_in_8bit" in self.model_kwargs:
del self.model_kwargs["load_in_8bit"]
if "load_in_4bit" in self.model_kwargs:
del self.model_kwargs["load_in_4bit"]
def set_attention_config(self) -> None:
"""
@@ -962,10 +973,17 @@ class ModelLoader:
if is_deepspeed_zero3_enabled():
skip_prepare_model_for_kbit_training = True
is_load_in_8bit = (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
)
is_load_in_4bit = (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
)
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
and (is_load_in_8bit or is_load_in_4bit)
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
self.model = prepare_model_for_kbit_training(
@@ -1103,10 +1121,16 @@ class ModelLoader:
# ---------------------------------------------------------
# put model to accelerator
# ---------------------------------------------------------
is_load_in_8bit = (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
)
is_load_in_4bit = (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
)
if (
self.cfg.ddp
and not self.cfg.load_in_8bit
and not (self.cfg.rl and self.cfg.load_in_4bit)
and not is_load_in_8bit
and not (self.cfg.rl and is_load_in_4bit)
and not skip_move_to_device
):
# TODO revaldate this conditional

View File

@@ -1,250 +0,0 @@
from typing import Optional
import torch
from torch import Tensor
from torch.distributed._tensor import DTensor
from torch.optim import Optimizer
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
class _ShampooBase(Optimizer):
def __init__(
self,
params,
lr=1e-1,
momentum=0.0,
weight_decay=0.0,
eps=1e-4,
update_freq=1,
*,
block_size,
quantization_bits,
optimizer_state_class,
):
if lr <= 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if eps < 0.0:
raise ValueError(f"Invalid eps value: {eps}")
if update_freq < 1:
raise ValueError(f"Invalid update_freq value: {update_freq}")
defaults = dict(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
eps=eps,
update_freq=update_freq,
)
super().__init__(params, defaults)
self.block_size = block_size
self.quantization_bits = quantization_bits
self.optimizer_state_class = optimizer_state_class
def step(self, closure: Optional[callable] = None) -> Optional[float]:
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
state["momentum_buffer"] = self._new_buffer(grad, True)
state["preconds"] = []
state["inv_preconds"] = []
for dim in grad.size():
state["preconds"].append(
self.optimizer_state_class.zeros(
(dim, dim),
signed=False,
block_size=self.block_size,
device=grad.device,
)
)
state["inv_preconds"].append(
torch.zeros((dim, dim), device=grad.device)
)
state["step"] += 1
beta = group["momentum"]
weight_decay = group["weight_decay"]
lr = group["lr"]
eps = group["eps"]
update_freq = group["update_freq"]
# Apply momentum
if beta > 0:
state["momentum_buffer"].mul_(beta).add_(grad, alpha=1 - beta)
grad = state["momentum_buffer"]
# Apply weight decay
if weight_decay > 0:
grad = grad.add(p.data, alpha=weight_decay)
# Preconditioning
order = grad.ndimension()
original_size = grad.size()
for dim_id, dim in enumerate(grad.size()):
precond = state["preconds"][dim_id]
inv_precond = state["inv_preconds"][dim_id]
# Reshape grad
grad = grad.transpose(0, dim_id).contiguous()
transposed_size = grad.size()
grad = grad.view(dim, -1)
grad_t = grad.t()
# Update preconditioner
precond_fp32 = precond.dequantize()
precond_update = grad @ grad_t
precond_fp32.add_(precond_update)
# Quantize preconditioner back
precond.copy_(precond_fp32)
# Update inverse preconditioner
if state["step"] % update_freq == 0:
inv_precond.copy_(
self._compute_inv_precond(precond_fp32, eps, order)
)
# Precondition grad
if dim_id == order - 1:
# Last dimension
grad = grad_t @ inv_precond
grad = grad.view(original_size)
else:
grad = inv_precond @ grad
grad = grad.view(transposed_size)
# Update parameter
p.data.add_(grad, alpha=-lr)
return loss
def _compute_inv_precond(self, precond: Tensor, eps: float, order: int):
# Add eps for numerical stability
precond = precond + torch.eye(precond.size(0), device=precond.device) * eps
# Compute matrix power
inv_precond = self._matrix_power(precond, -1.0 / (2 * order))
return inv_precond
def _matrix_power(self, matrix: Tensor, power: float) -> Tensor:
# Compute matrix power using SVD
u, s, v = torch.svd(matrix)
s_pow = s.pow(power)
return u @ torch.diag(s_pow) @ v.t()
# bring your own function to create zero-filled subclass
@staticmethod
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
raise NotImplementedError
# follow bitsandbytes, only quantize tensors >= 4096 values
# also wrap subclass in DTensor when needed
def _new_buffer(self, p: Tensor, signed: bool):
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
if isinstance(p, DTensor):
out = DTensor.from_local(
local_tensor=self._subclass_zeros(
p.to_local(), signed, self.block_size
),
device_mesh=p.device_mesh,
placements=p.placements,
run_check=False,
)
else:
out = self._subclass_zeros(p, signed, self.block_size)
else:
out = torch.zeros_like(p)
return out
class Shampoo8bit(_ShampooBase):
def __init__(
self,
params,
lr=1e-1,
momentum=0.0,
weight_decay=0.0,
eps=1e-4,
update_freq=1,
*,
block_size=256,
):
super().__init__(
params,
lr,
momentum,
weight_decay,
eps,
update_freq,
block_size=block_size,
quantization_bits=8,
optimizer_state_class=OptimState8bit,
)
class Shampoo4bit(_ShampooBase):
def __init__(
self,
params,
lr=1e-1,
momentum=0.0,
weight_decay=0.0,
eps=1e-4,
update_freq=1,
*,
block_size=128,
):
super().__init__(
params,
lr,
momentum,
weight_decay,
eps,
update_freq,
block_size=block_size,
quantization_bits=4,
optimizer_state_class=OptimState4bit,
)
class ShampooFp8(_ShampooBase):
def __init__(
self,
params,
lr=1e-1,
momentum=0.0,
weight_decay=0.0,
eps=1e-4,
update_freq=1,
*,
block_size=256,
):
super().__init__(
params,
lr,
momentum,
weight_decay,
eps,
update_freq,
block_size=block_size,
quantization_bits=8, # FP8 uses 8 bits
optimizer_state_class=OptimStateFp8,
)

View File

@@ -1,6 +1,7 @@
"""
Simple end-to-end test for Liger integration
"""
import unittest
from pathlib import Path

View File

@@ -14,7 +14,7 @@ from huggingface_hub import snapshot_download
from axolotl.utils.dict import DictDefault
from ..utils import is_hopper, with_temp_dir
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -59,7 +59,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -116,7 +116,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 50,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -144,146 +144,6 @@ class TestMultiGPULlama(unittest.TestCase):
]
)
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
@with_temp_dir
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_dpo_qlora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_fsdp(self, temp_dir):
# pylint: disable=duplicate-code
@@ -305,7 +165,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -371,7 +231,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -413,6 +273,7 @@ class TestMultiGPULlama(unittest.TestCase):
]
)
@pytest.mark.skip("disabled due to upstream issue")
@with_temp_dir
def test_fsdp_qlora_prequant_packed(self, temp_dir):
# pylint: disable=duplicate-code
@@ -421,7 +282,6 @@ class TestMultiGPULlama(unittest.TestCase):
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
"tokenizer_type": "AutoTokenizer",
"adapter": "qlora",
"mean_resizing_embeddings": True,
"load_in_4bit": True,
"lora_r": 8,
"lora_alpha": 16,
@@ -437,7 +297,7 @@ class TestMultiGPULlama(unittest.TestCase):
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "</s>",
"pad_token": "<|end_of_text|>",
},
"datasets": [
{
@@ -447,7 +307,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -513,7 +373,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -572,7 +432,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,

View File

@@ -47,7 +47,7 @@ class TestMultiGPUQwen2(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"warmup_steps": 20,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import require_torch_2_3_1, with_temp_dir
from ..utils import require_torch_2_1_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -24,7 +24,7 @@ class Test4dMultipackLlama(unittest.TestCase):
Test case for Llama models using 4d attention with multipack
"""
@require_torch_2_3_1
@require_torch_2_1_1
@with_temp_dir
def test_sdp_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -9,8 +9,6 @@ from functools import wraps
from importlib.metadata import version
from pathlib import Path
import torch
def with_temp_dir(test_func):
@wraps(test_func)
@@ -37,18 +35,13 @@ def most_recent_subdir(path):
return subdir
def require_torch_2_3_1(test_case):
def require_torch_2_1_1(test_case):
"""
Decorator marking a test that requires torch >= 2.3.1
Decorator marking a test that requires torch >= 2.1.1
"""
def is_min_2_3_1():
def is_min_2_1_1():
torch_version = version("torch")
return torch_version >= "2.3.1"
return torch_version >= "2.1.1"
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
def is_hopper():
compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0)
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)

View File

@@ -1,80 +0,0 @@
"""
config validation tests for swiglu args
"""
# pylint: disable=duplicate-code
import logging
from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_base_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
class BaseValidation:
"""
Base validation module to setup the log capture
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
"""
Test the validation module for liger
"""
def test_deprecated_swiglu(self, minimal_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
}
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
updated_cfg = validate_config(test_cfg)
assert (
"The 'liger_swiglu' argument is deprecated"
in self._caplog.records[0].message
)
assert updated_cfg.liger_swiglu is None
assert updated_cfg.liger_glu_activations is False
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
"liger_glu_activations": True,
}
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
):
validate_config(test_cfg)

View File

@@ -306,10 +306,6 @@ class TestDatasetPreparation(unittest.TestCase):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
# make sure prepared_path is empty
shutil.rmtree(prepared_path, ignore_errors=True)
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
@@ -371,44 +367,43 @@ class TestDatasetPreparation(unittest.TestCase):
def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
with tempfile.TemporaryDirectory() as tmp_dir2:
tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
f"{tmp_ds_path}/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":