Compare commits
5 Commits
e766a730ba
...
llmcompres
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b708a1cc45 | ||
|
|
daa9a58f83 | ||
|
|
ae7069e15b | ||
|
|
20d48cd617 | ||
|
|
1447beb132 |
@@ -49,7 +49,8 @@ sections = [
|
|||||||
("Knowledge Distillation (KD)", "kd"),
|
("Knowledge Distillation (KD)", "kd"),
|
||||||
("Liger Kernels", "liger"),
|
("Liger Kernels", "liger"),
|
||||||
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
|
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
|
||||||
("Spectrum", "spectrum")
|
("Spectrum", "spectrum"),
|
||||||
|
("LLMCompressor", "llm_compressor")
|
||||||
]
|
]
|
||||||
|
|
||||||
for section_name, folder_name in sections:
|
for section_name, folder_name in sections:
|
||||||
|
|||||||
@@ -1,98 +0,0 @@
|
|||||||
---
|
|
||||||
title: "LLMCompressor Sparse Fine-tuning"
|
|
||||||
format:
|
|
||||||
html:
|
|
||||||
toc: true
|
|
||||||
toc-depth: 3
|
|
||||||
number-sections: true
|
|
||||||
execute:
|
|
||||||
enabled: false
|
|
||||||
---
|
|
||||||
|
|
||||||
# LLMCompressor Integration
|
|
||||||
|
|
||||||
Fine-tune sparsified models in Axolotl using [LLMCompressor](https://github.com/vllm-project/llm-compressor).
|
|
||||||
|
|
||||||
This integration enables fine-tuning of models **already sparsified** using LLMCompressor.
|
|
||||||
It hooks into Axolotl’s training pipeline using the plugin system and maintains sparsity throughout the fine-tuning process.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Requirements
|
|
||||||
|
|
||||||
- Install Axolotl with `llmcompressor` extras:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install "axolotl[llmcompressor]"
|
|
||||||
```
|
|
||||||
|
|
||||||
- Requires `llmcompressor >= 0.5.1`
|
|
||||||
|
|
||||||
This will install all required dependencies for sparse model fine-tuning.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
To enable sparse fine-tuning with this integration, configure your Axolotl YAML like so:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
|
|
||||||
|
|
||||||
llmcompressor:
|
|
||||||
recipe:
|
|
||||||
finetuning_stage:
|
|
||||||
finetuning_modifiers:
|
|
||||||
ConstantPruningModifier:
|
|
||||||
targets: [
|
|
||||||
're:.*q_proj.weight',
|
|
||||||
're:.*k_proj.weight',
|
|
||||||
're:.*v_proj.weight',
|
|
||||||
're:.*o_proj.weight',
|
|
||||||
're:.*gate_proj.weight',
|
|
||||||
're:.*up_proj.weight',
|
|
||||||
're:.*down_proj.weight',
|
|
||||||
]
|
|
||||||
start: 0
|
|
||||||
# ... (other Axolotl training arguments)
|
|
||||||
```
|
|
||||||
|
|
||||||
::: {.callout-note}
|
|
||||||
This plugin **does not prune or sparsify the model**. It is only meant for **fine-tuning models that are already sparsified**.
|
|
||||||
:::
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Pre-Sparsified Checkpoints
|
|
||||||
|
|
||||||
You can use:
|
|
||||||
|
|
||||||
- Your own LLMCompressor-sparsified model
|
|
||||||
- Or one from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
|
|
||||||
|
|
||||||
Refer to the [LLMCompressor README](https://github.com/vllm-project/llm-compressor/blob/main/README.md) to learn how to sparsify models or write custom recipes.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Example Config
|
|
||||||
|
|
||||||
A full working example is provided at:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
examples/llama-3/sparse-finetuning.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
Run fine-tuning using:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
axolotl train examples/llama-3/sparse-finetuning.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Learn More
|
|
||||||
|
|
||||||
Explore LLMCompressor capabilities, supported modifiers, and detailed examples:
|
|
||||||
|
|
||||||
👉 [LLMCompressor GitHub](https://github.com/vllm-project/llm-compressor)
|
|
||||||
@@ -45,6 +45,7 @@ llmcompressor:
|
|||||||
're:.*down_proj.weight',
|
're:.*down_proj.weight',
|
||||||
]
|
]
|
||||||
start: 0
|
start: 0
|
||||||
|
save_compressed: true
|
||||||
# ... (other training arguments)
|
# ... (other training arguments)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -52,19 +53,56 @@ This plugin **does not apply pruning or sparsification itself** — it is intend
|
|||||||
|
|
||||||
Pre-sparsified checkpoints can be:
|
Pre-sparsified checkpoints can be:
|
||||||
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
|
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
|
||||||
- Or downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
|
- Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
|
||||||
|
- Any custom LLM with compatible sparsity patterns that you've created yourself
|
||||||
|
|
||||||
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
|
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
|
||||||
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
|
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
|
||||||
|
|
||||||
|
### Storage Optimization with save_compressed
|
||||||
|
|
||||||
|
Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
|
||||||
|
- Reduces disk space usage by approximately 40%
|
||||||
|
- Maintains compatibility with vLLM for accelerated inference
|
||||||
|
- Maintains compatibility with llmcompressor for further optimization (for example, quantization)
|
||||||
|
|
||||||
|
This option is highly recommended when working with sparse models to maximize the benefits of model compression.
|
||||||
|
|
||||||
### Example Config
|
### Example Config
|
||||||
|
|
||||||
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
|
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Inference with vLLM
|
||||||
|
|
||||||
|
After fine-tuning your sparse model, you can leverage vLLM for efficient inference.
|
||||||
|
You can also use LLMCompressor to apply additional quantization to your fine-tuned
|
||||||
|
sparse model before inference for even greater performance benefits.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
llm = LLM("path/to/your/sparse/model")
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
```
|
||||||
|
|
||||||
|
For more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).
|
||||||
|
|
||||||
## Learn More
|
## Learn More
|
||||||
|
|
||||||
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
|
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
|
||||||
|
|
||||||
👉 [https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)
|
[https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from axolotl.cli.args import TrainerCliArgs
|
|||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils import get_pytorch_version
|
from axolotl.utils import get_pytorch_version
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists
|
from ..utils import check_model_output_exists
|
||||||
@@ -56,6 +56,7 @@ class TestCutCrossEntropyIntegration:
|
|||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
def test_llama_w_cce(self, min_cfg, temp_dir):
|
def test_llama_w_cce(self, min_cfg, temp_dir):
|
||||||
cfg = DictDefault(min_cfg)
|
cfg = DictDefault(min_cfg)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
@@ -101,6 +102,7 @@ class TestCutCrossEntropyIntegration:
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
@@ -129,6 +131,7 @@ class TestCutCrossEntropyIntegration:
|
|||||||
attention_type: True,
|
attention_type: True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ Simple end-to-end test for Liger integration
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
||||||
@@ -54,6 +54,7 @@ class LigerIntegrationTestCase:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
@@ -100,6 +101,7 @@ class LigerIntegrationTestCase:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import pytest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
||||||
@@ -79,6 +79,7 @@ class TestLLMCompressorIntegration:
|
|||||||
)
|
)
|
||||||
|
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -60,6 +60,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
|||||||
"fp16": True,
|
"fp16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -104,6 +105,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
|||||||
"fp16": True,
|
"fp16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -63,6 +63,7 @@ class TestFalconPatched(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -103,6 +104,7 @@ class TestFalconPatched(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -67,6 +67,7 @@ class TestFusedLlama(unittest.TestCase):
|
|||||||
cfg.bf16 = True
|
cfg.bf16 = True
|
||||||
else:
|
else:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import pytest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -65,6 +65,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -105,6 +106,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_availab
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -70,6 +70,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -120,6 +121,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -63,6 +63,7 @@ class TestMistral(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -104,6 +105,7 @@ class TestMistral(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -60,6 +60,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import unittest
|
|||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
|
|
||||||
@@ -47,6 +47,7 @@ class TestModelPatches(unittest.TestCase):
|
|||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
load_model(cfg, tokenizer, inference=False)
|
load_model(cfg, tokenizer, inference=False)
|
||||||
@@ -79,6 +80,7 @@ class TestModelPatches(unittest.TestCase):
|
|||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
load_model(cfg, tokenizer, inference=False)
|
load_model(cfg, tokenizer, inference=False)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -63,6 +63,7 @@ class TestPhiMultipack(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -82,7 +83,7 @@ class TestPhiMultipack(unittest.TestCase):
|
|||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"load_in_8bit": False,
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 64,
|
"lora_r": 64,
|
||||||
"lora_alpha": 32,
|
"lora_alpha": 32,
|
||||||
@@ -114,6 +115,7 @@ class TestPhiMultipack(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, most_recent_subdir
|
from ..utils import check_model_output_exists, most_recent_subdir
|
||||||
@@ -68,6 +68,7 @@ class TestResumeLlama:
|
|||||||
cfg.bf16 = True
|
cfg.bf16 = True
|
||||||
else:
|
else:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import pytest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, check_tensorboard
|
from ..utils import check_model_output_exists, check_tensorboard
|
||||||
@@ -72,6 +72,7 @@ class TestUnslothQLoRA:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -122,6 +123,7 @@ class TestUnslothQLoRA:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -177,6 +179,7 @@ class TestUnslothQLoRA:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
|||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -109,6 +109,7 @@ class TestLlamaVision(unittest.TestCase):
|
|||||||
"bf16": True,
|
"bf16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ class TestPhi(unittest.TestCase):
|
|||||||
"tokenizer_type": "AutoTokenizer",
|
"tokenizer_type": "AutoTokenizer",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"load_in_8bit": False,
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 64,
|
"lora_r": 64,
|
||||||
"lora_alpha": 32,
|
"lora_alpha": 32,
|
||||||
@@ -111,6 +111,7 @@ class TestPhi(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||||
@@ -57,6 +57,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
|
|||||||
"seed": 42,
|
"seed": 42,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||||
@@ -319,6 +319,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
self.cfg_1 = validate_config(self.cfg_1)
|
||||||
normalize_config(self.cfg_1)
|
normalize_config(self.cfg_1)
|
||||||
|
|
||||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||||
|
|||||||
Reference in New Issue
Block a user