relora: magnitude pruning of the optimizer (#1245)
* magnitude pruning of the optimizer * add alpaca chat template and fix relora patch * fix handling of lora adapter for relora * fix merge and save call * fixes for 8-bit lora merge * save intermediate checkpoint adapters * auto merge * fix eval check * handle relora annealing * fix anneal step logic * chore: lint * misx fix * fix types * Update tests/e2e/test_relora_llama.py * check for safetensors saved from relora
This commit is contained in:
@@ -7,8 +7,6 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
@@ -63,6 +61,7 @@ class TestMistral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
@@ -103,12 +102,9 @@ class TestMistral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
Reference in New Issue
Block a user