Compare commits
5 Commits
e766a730ba
...
b708a1cc45
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b708a1cc45 | ||
|
|
daa9a58f83 | ||
|
|
ae7069e15b | ||
|
|
20d48cd617 | ||
|
|
1447beb132 |
@@ -49,7 +49,8 @@ sections = [
|
||||
("Knowledge Distillation (KD)", "kd"),
|
||||
("Liger Kernels", "liger"),
|
||||
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
|
||||
("Spectrum", "spectrum")
|
||||
("Spectrum", "spectrum"),
|
||||
("LLMCompressor", "llm_compressor")
|
||||
]
|
||||
|
||||
for section_name, folder_name in sections:
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
---
|
||||
title: "LLMCompressor Sparse Fine-tuning"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
number-sections: true
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
# LLMCompressor Integration
|
||||
|
||||
Fine-tune sparsified models in Axolotl using [LLMCompressor](https://github.com/vllm-project/llm-compressor).
|
||||
|
||||
This integration enables fine-tuning of models **already sparsified** using LLMCompressor.
|
||||
It hooks into Axolotl’s training pipeline using the plugin system and maintains sparsity throughout the fine-tuning process.
|
||||
|
||||
---
|
||||
|
||||
## Requirements
|
||||
|
||||
- Install Axolotl with `llmcompressor` extras:
|
||||
|
||||
```bash
|
||||
pip install "axolotl[llmcompressor]"
|
||||
```
|
||||
|
||||
- Requires `llmcompressor >= 0.5.1`
|
||||
|
||||
This will install all required dependencies for sparse model fine-tuning.
|
||||
|
||||
---
|
||||
|
||||
## Usage
|
||||
|
||||
To enable sparse fine-tuning with this integration, configure your Axolotl YAML like so:
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
|
||||
|
||||
llmcompressor:
|
||||
recipe:
|
||||
finetuning_stage:
|
||||
finetuning_modifiers:
|
||||
ConstantPruningModifier:
|
||||
targets: [
|
||||
're:.*q_proj.weight',
|
||||
're:.*k_proj.weight',
|
||||
're:.*v_proj.weight',
|
||||
're:.*o_proj.weight',
|
||||
're:.*gate_proj.weight',
|
||||
're:.*up_proj.weight',
|
||||
're:.*down_proj.weight',
|
||||
]
|
||||
start: 0
|
||||
# ... (other Axolotl training arguments)
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
This plugin **does not prune or sparsify the model**. It is only meant for **fine-tuning models that are already sparsified**.
|
||||
:::
|
||||
|
||||
---
|
||||
|
||||
## Pre-Sparsified Checkpoints
|
||||
|
||||
You can use:
|
||||
|
||||
- Your own LLMCompressor-sparsified model
|
||||
- Or one from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
|
||||
|
||||
Refer to the [LLMCompressor README](https://github.com/vllm-project/llm-compressor/blob/main/README.md) to learn how to sparsify models or write custom recipes.
|
||||
|
||||
---
|
||||
|
||||
## Example Config
|
||||
|
||||
A full working example is provided at:
|
||||
|
||||
```bash
|
||||
examples/llama-3/sparse-finetuning.yaml
|
||||
```
|
||||
|
||||
Run fine-tuning using:
|
||||
|
||||
```bash
|
||||
axolotl train examples/llama-3/sparse-finetuning.yaml
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Learn More
|
||||
|
||||
Explore LLMCompressor capabilities, supported modifiers, and detailed examples:
|
||||
|
||||
👉 [LLMCompressor GitHub](https://github.com/vllm-project/llm-compressor)
|
||||
@@ -45,6 +45,7 @@ llmcompressor:
|
||||
're:.*down_proj.weight',
|
||||
]
|
||||
start: 0
|
||||
save_compressed: true
|
||||
# ... (other training arguments)
|
||||
```
|
||||
|
||||
@@ -52,19 +53,56 @@ This plugin **does not apply pruning or sparsification itself** — it is intend
|
||||
|
||||
Pre-sparsified checkpoints can be:
|
||||
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
|
||||
- Or downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
|
||||
- Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
|
||||
- Any custom LLM with compatible sparsity patterns that you've created yourself
|
||||
|
||||
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
|
||||
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
|
||||
|
||||
### Storage Optimization with save_compressed
|
||||
|
||||
Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
|
||||
- Reduces disk space usage by approximately 40%
|
||||
- Maintains compatibility with vLLM for accelerated inference
|
||||
- Maintains compatibility with llmcompressor for further optimization (example: quantization)
|
||||
|
||||
This option is highly recommended when working with sparse models to maximize the benefits of model compression.
|
||||
|
||||
### Example Config
|
||||
|
||||
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
|
||||
|
||||
---
|
||||
|
||||
## Inference with vLLM
|
||||
|
||||
After fine-tuning your sparse model, you can leverage vLLM for efficient inference.
|
||||
You can also use LLMCompressor to apply additional quantization to your fine-tuned
|
||||
sparse model before inference for even greater performance benefits.:
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
llm = LLM("path/to/your/sparse/model")
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
```
|
||||
|
||||
For more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).
|
||||
|
||||
## Learn More
|
||||
|
||||
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
|
||||
|
||||
👉 [https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)
|
||||
[https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)
|
||||
|
||||
@@ -8,7 +8,7 @@ from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import get_pytorch_version
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists
|
||||
@@ -56,6 +56,7 @@ class TestCutCrossEntropyIntegration:
|
||||
# pylint: disable=redefined-outer-name
|
||||
def test_llama_w_cce(self, min_cfg, temp_dir):
|
||||
cfg = DictDefault(min_cfg)
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
@@ -101,6 +102,7 @@ class TestCutCrossEntropyIntegration:
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
@@ -129,6 +131,7 @@ class TestCutCrossEntropyIntegration:
|
||||
attention_type: True,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
|
||||
@@ -5,7 +5,7 @@ Simple end-to-end test for Liger integration
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
||||
@@ -54,6 +54,7 @@ class LigerIntegrationTestCase:
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
@@ -100,6 +101,7 @@ class LigerIntegrationTestCase:
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
||||
@@ -79,6 +79,7 @@ class TestLLMCompressorIntegration:
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
@@ -60,6 +60,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -104,6 +105,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
@@ -63,6 +63,7 @@ class TestFalconPatched(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -103,6 +104,7 @@ class TestFalconPatched(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
@@ -67,6 +67,7 @@ class TestFusedLlama(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -11,7 +11,7 @@ import pytest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
@@ -65,6 +65,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -105,6 +106,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -12,7 +12,7 @@ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_availab
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
@@ -70,6 +70,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -120,6 +121,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
@@ -63,6 +63,7 @@ class TestMistral(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -104,6 +105,7 @@ class TestMistral(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
@@ -60,6 +60,7 @@ class TestMixtral(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -6,7 +6,7 @@ import unittest
|
||||
|
||||
import transformers
|
||||
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
@@ -47,6 +47,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer, inference=False)
|
||||
@@ -79,6 +80,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer, inference=False)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
@@ -63,6 +63,7 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -82,7 +83,7 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
"sample_packing": True,
|
||||
"flash_attention": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"load_in_8bit": False,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 64,
|
||||
"lora_alpha": 32,
|
||||
@@ -114,6 +115,7 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, most_recent_subdir
|
||||
@@ -68,6 +68,7 @@ class TestResumeLlama:
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -10,7 +10,7 @@ import pytest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
@@ -72,6 +72,7 @@ class TestUnslothQLoRA:
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -122,6 +123,7 @@ class TestUnslothQLoRA:
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -177,6 +179,7 @@ class TestUnslothQLoRA:
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -102,6 +102,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -109,6 +109,7 @@ class TestLlamaVision(unittest.TestCase):
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -79,7 +79,7 @@ class TestPhi(unittest.TestCase):
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": False,
|
||||
"load_in_8bit": False,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 64,
|
||||
"lora_alpha": 32,
|
||||
@@ -111,6 +111,7 @@ class TestPhi(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
@@ -57,6 +57,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -11,7 +11,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||
@@ -319,6 +319,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
}
|
||||
)
|
||||
self.cfg_1 = validate_config(self.cfg_1)
|
||||
normalize_config(self.cfg_1)
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||
|
||||
Reference in New Issue
Block a user