Merge branch 'main' into llmcompressor-sft
This commit is contained in:
@@ -8,7 +8,7 @@ from axolotl.cli.args import TrainerCliArgs
|
|||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils import get_pytorch_version
|
from axolotl.utils import get_pytorch_version
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists
|
from ..utils import check_model_output_exists
|
||||||
@@ -56,6 +56,7 @@ class TestCutCrossEntropyIntegration:
|
|||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
def test_llama_w_cce(self, min_cfg, temp_dir):
|
def test_llama_w_cce(self, min_cfg, temp_dir):
|
||||||
cfg = DictDefault(min_cfg)
|
cfg = DictDefault(min_cfg)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
@@ -101,6 +102,7 @@ class TestCutCrossEntropyIntegration:
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
@@ -129,6 +131,7 @@ class TestCutCrossEntropyIntegration:
|
|||||||
attention_type: True,
|
attention_type: True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ Simple end-to-end test for Liger integration
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
||||||
@@ -54,6 +54,7 @@ class LigerIntegrationTestCase:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
@@ -100,6 +101,7 @@ class LigerIntegrationTestCase:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -60,6 +60,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
|||||||
"fp16": True,
|
"fp16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -104,6 +105,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
|||||||
"fp16": True,
|
"fp16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -63,6 +63,7 @@ class TestFalconPatched(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -103,6 +104,7 @@ class TestFalconPatched(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -67,6 +67,7 @@ class TestFusedLlama(unittest.TestCase):
|
|||||||
cfg.bf16 = True
|
cfg.bf16 = True
|
||||||
else:
|
else:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import pytest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -65,6 +65,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -105,6 +106,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_availab
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -70,6 +70,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -120,6 +121,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -63,6 +63,7 @@ class TestMistral(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -104,6 +105,7 @@ class TestMistral(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -60,6 +60,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import unittest
|
|||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
|
|
||||||
@@ -47,6 +47,7 @@ class TestModelPatches(unittest.TestCase):
|
|||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
load_model(cfg, tokenizer, inference=False)
|
load_model(cfg, tokenizer, inference=False)
|
||||||
@@ -79,6 +80,7 @@ class TestModelPatches(unittest.TestCase):
|
|||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
load_model(cfg, tokenizer, inference=False)
|
load_model(cfg, tokenizer, inference=False)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -63,6 +63,7 @@ class TestPhiMultipack(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -82,7 +83,7 @@ class TestPhiMultipack(unittest.TestCase):
|
|||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"load_in_8bit": False,
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 64,
|
"lora_r": 64,
|
||||||
"lora_alpha": 32,
|
"lora_alpha": 32,
|
||||||
@@ -114,6 +115,7 @@ class TestPhiMultipack(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, most_recent_subdir
|
from ..utils import check_model_output_exists, most_recent_subdir
|
||||||
@@ -68,6 +68,7 @@ class TestResumeLlama:
|
|||||||
cfg.bf16 = True
|
cfg.bf16 = True
|
||||||
else:
|
else:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import pytest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, check_tensorboard
|
from ..utils import check_model_output_exists, check_tensorboard
|
||||||
@@ -72,6 +72,7 @@ class TestUnslothQLoRA:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -122,6 +123,7 @@ class TestUnslothQLoRA:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -177,6 +179,7 @@ class TestUnslothQLoRA:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
|||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -109,6 +109,7 @@ class TestLlamaVision(unittest.TestCase):
|
|||||||
"bf16": True,
|
"bf16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ class TestPhi(unittest.TestCase):
|
|||||||
"tokenizer_type": "AutoTokenizer",
|
"tokenizer_type": "AutoTokenizer",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"load_in_8bit": False,
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 64,
|
"lora_r": 64,
|
||||||
"lora_alpha": 32,
|
"lora_alpha": 32,
|
||||||
@@ -111,6 +111,7 @@ class TestPhi(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||||
@@ -57,6 +57,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
|
|||||||
"seed": 42,
|
"seed": 42,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||||
@@ -319,6 +319,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
self.cfg_1 = validate_config(self.cfg_1)
|
||||||
normalize_config(self.cfg_1)
|
normalize_config(self.cfg_1)
|
||||||
|
|
||||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||||
|
|||||||
Reference in New Issue
Block a user