Feat: Add bitnet integration (#3634)
* add bitnet * switch to uv * chore: liint --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -311,6 +311,7 @@ website:
|
||||
- docs/dataset_loading.qmd
|
||||
- docs/qat.qmd
|
||||
- docs/quantize.qmd
|
||||
- docs/1_58bit_finetuning.qmd
|
||||
- docs/optimizations.qmd
|
||||
|
||||
- section: "Core Concepts"
|
||||
|
||||
70
docs/1_58bit_finetuning.qmd
Normal file
70
docs/1_58bit_finetuning.qmd
Normal file
@@ -0,0 +1,70 @@
|
||||
---
|
||||
title: "1.58-bit Finetuning"
|
||||
back-to-top-navigation: true
|
||||
toc: true
|
||||
toc-expand: 2
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
1.58-bit finetuning allows you to finetune BitNet models when their prequantized weights are provided. In theory, it will be possible to fine-tune any LLM in 1.58bit format but the performance degradation will be dramatic.
|
||||
|
||||
Axolotl supports 1.58-bit finetuning via the [`onebitllms`](https://github.com/tiiuae/onebitllms) library, which replaces standard linear layers with BitNet-compatible counterparts ready to use for training.
|
||||
|
||||
::: {.callout-note}
|
||||
LoRA is not supported for BitNet models
|
||||
:::
|
||||
|
||||
## Installation
|
||||
|
||||
Install the `onebitllms` package before using this feature:
|
||||
|
||||
```bash
|
||||
uv pip install onebitllms
|
||||
```
|
||||
|
||||
Or from source:
|
||||
|
||||
```bash
|
||||
uv pip install git+https://github.com/tiiuae/onebitllms
|
||||
```
|
||||
|
||||
## Supported models
|
||||
|
||||
For now, only `Falcon-E` series of models are supported. Make sure to use their `-prequantized` version:
|
||||
|
||||
```bash
|
||||
tiiuae/Falcon-E-3B-Base-prequantized
|
||||
tiiuae/Falcon-E-1B-Base-prequantized
|
||||
```
|
||||
|
||||
In theory, any other model would 'work' but the performance degradation will be huge. This remains an area of exploration.
|
||||
|
||||
## Configuration
|
||||
|
||||
To enable 1.58-bit finetuning, set the following in your configuration file:
|
||||
|
||||
```yaml
|
||||
base_model: tiiuae/Falcon-E-3B-Base-prequantized # A BitNet-compatible model
|
||||
|
||||
use_onebitllms: true
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
For BitNet models, it is recommended to use a higher learning rate than classic models (usually in the order of magnitude of 10x).
|
||||
:::
|
||||
|
||||
## Considerations after training
|
||||
|
||||
Once your model has been trained with 1.58bit fine-tuning, you can convert the trained model in ternary format using the `onebitllms` CLI:
|
||||
|
||||
```bash
|
||||
onebitllms quantize_to_1bit INPUT_PATH OUTPUT_PATH
|
||||
```
|
||||
|
||||
After that, you can use supported packages such as `llama.cpp` or Apple MLX package to run the trained model.
|
||||
|
||||
## Example Configuration
|
||||
|
||||
You can find example configurations in `examples/falcon-e` which contain one configuration for SFT and one configuration for DPO.
|
||||
@@ -846,6 +846,17 @@ class ModelLoader:
|
||||
else:
|
||||
self.model = self._load_model_from_pretrained(model_loader_class)
|
||||
|
||||
if self.cfg.use_onebitllms:
|
||||
try:
|
||||
from onebitllms import replace_linear_with_bitnet_linear
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"The 'onebitllms' package is required for use_onebitllms. "
|
||||
"Install it with: `uv pip install onebitllms`"
|
||||
) from exc
|
||||
|
||||
self.model = replace_linear_with_bitnet_linear(self.model)
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
skip_move_to_device = True
|
||||
|
||||
|
||||
@@ -103,6 +103,12 @@ class ModelInputConfig(BaseModel):
|
||||
default=None,
|
||||
json_schema_extra={"description": "kwargs for model quantization config"},
|
||||
)
|
||||
use_onebitllms: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to use `onebitllms` for 1.58bit training (only for bitnet models)."
|
||||
},
|
||||
)
|
||||
|
||||
@field_validator("trust_remote_code")
|
||||
@classmethod
|
||||
|
||||
@@ -638,6 +638,12 @@ class LoRAValidationMixin:
|
||||
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_onebitllms_lora(self):
|
||||
if self.use_onebitllms and self.adapter in ["lora", "qlora"]:
|
||||
raise ValueError("LoRA/QLoRA is not supported with use_onebitllms")
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def warn_qlora_zero3_w_use_reentrant(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user