From dd26cc3c0fddd6164633497d1a5416bc7a3c536b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 13 Jan 2025 10:43:29 -0500
Subject: [PATCH 01/16] add helper to verify the correct model output file
 exists (#2245)

* add helper to verify the correct model output file exists

* more checks using helper

* chore: lint

* fix import and relora model check

* workaround for trl trainer saves

* remove stray print
---
 src/axolotl/cli/utils.py                       |  1 -
 .../integrations/test_cut_cross_entropy.py     |  8 +++---
 tests/e2e/integrations/test_liger.py           |  7 ++---
 tests/e2e/patched/test_4d_multipack_llama.py   |  7 +++--
 tests/e2e/patched/test_fa_xentropy.py          |  5 ++--
 tests/e2e/patched/test_falcon_samplepack.py    |  7 +++--
 tests/e2e/patched/test_fused_llama.py          |  5 ++--
 tests/e2e/patched/test_llama_s2_attention.py   |  7 +++--
 .../e2e/patched/test_lora_llama_multipack.py   |  7 +++--
 tests/e2e/patched/test_mistral_samplepack.py   |  7 +++--
 tests/e2e/patched/test_mixtral_samplepack.py   |  7 +++--
 tests/e2e/patched/test_phi_multipack.py        |  7 +++--
 tests/e2e/patched/test_resume.py               |  5 ++--
 tests/e2e/patched/test_unsloth_qlora.py        |  9 +++----
 tests/e2e/test_dpo.py                          | 16 ++++++------
 tests/e2e/test_embeddings_lr.py                |  7 +++--
 tests/e2e/test_falcon.py                       |  9 +++----
 tests/e2e/test_llama.py                        |  9 ++++---
 tests/e2e/test_llama_pretrain.py               |  5 ++--
 tests/e2e/test_llama_vision.py                 |  7 +++--
 tests/e2e/test_lora_llama.py                   |  5 ++--
 tests/e2e/test_mamba.py                        |  5 ++--
 tests/e2e/test_mistral.py                      |  7 +++--
 tests/e2e/test_mixtral.py                      | 13 +++++-----
 tests/e2e/test_optimizers.py                   |  9 +++----
 tests/e2e/test_phi.py                          |  7 +++--
 tests/e2e/test_relora_llama.py                 |  8 +++---
 tests/e2e/test_reward_model_llama.py           |  5 ++--
 tests/e2e/utils.py                             | 26 +++++++++++++++++++
 29 files changed, 116 insertions(+), 111 deletions(-)

diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py
index f0e2573f7..85d241b5d 100644
--- a/src/axolotl/cli/utils.py
+++ b/src/axolotl/cli/utils.py
@@ -27,7 +27,6 @@ def add_options_from_dataclass(config_class: Type[Any]):
             field_type = next(
                 t for t in get_args(field_type) if not isinstance(t, NoneType)
             )
-
         if field_type == bool:
             field_name = field.name.replace("_", "-")
             option_name = f"--{field_name}/--no-{field_name}"
diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py
index a74813e3a..6562af176 100644
--- a/tests/e2e/integrations/test_cut_cross_entropy.py
+++ b/tests/e2e/integrations/test_cut_cross_entropy.py
@@ -2,8 +2,6 @@
 Simple end-to-end test for Cut Cross Entropy integration
 """

-from pathlib import Path
-
 import pytest

 from axolotl.cli import load_datasets
@@ -13,6 +11,8 @@ from axolotl.utils import get_pytorch_version
 from axolotl.utils.config import normalize_config, prepare_plugins
 from axolotl.utils.dict import DictDefault

+from ..utils import check_model_output_exists
+
 # pylint: disable=duplicate-code


@@ -67,7 +67,7 @@ class TestCutCrossEntropyIntegration:
                 train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         else:
             train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-            assert (Path(temp_dir) / "model.safetensors").exists()
+            check_model_output_exists(temp_dir, cfg)

     @pytest.mark.parametrize(
         "attention_type",
@@ -95,4 +95,4 @@ class TestCutCrossEntropyIntegration:
                 train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         else:
             train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-            assert (Path(temp_dir) / "model.safetensors").exists()
+            check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py
index ce9299b92..9154bf9b8 100644
--- a/tests/e2e/integrations/test_liger.py
+++ b/tests/e2e/integrations/test_liger.py
@@ -1,7 +1,6 @@
 """
 Simple end-to-end test for Liger integration
 """
-from pathlib import Path

 from e2e.utils import require_torch_2_4_1

@@ -11,6 +10,8 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config, prepare_plugins
 from axolotl.utils.dict import DictDefault

+from ..utils import check_model_output_exists
+

 class LigerIntegrationTestCase:
     """
@@ -60,7 +61,7 @@ class LigerIntegrationTestCase:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

     @require_torch_2_4_1
     def test_llama_w_flce(self, temp_dir):
@@ -105,4 +106,4 @@ class LigerIntegrationTestCase:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py
index b0ada9230..08b3bf0da 100644
--- a/tests/e2e/patched/test_4d_multipack_llama.py
+++ b/tests/e2e/patched/test_4d_multipack_llama.py
@@ -5,7 +5,6 @@ E2E tests for multipack fft llama using 4d attention masks
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import require_torch_2_3_1, with_temp_dir
+from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -67,7 +66,7 @@ class Test4dMultipackLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_torch_lora_packing(self, temp_dir):
@@ -111,4 +110,4 @@ class Test4dMultipackLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py
index 183843b7b..791d955b2 100644
--- a/tests/e2e/patched/test_fa_xentropy.py
+++ b/tests/e2e/patched/test_fa_xentropy.py
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import logging
 import os
-from pathlib import Path

 import pytest
 from transformers.utils import is_torch_bf16_gpu_available
@@ -15,7 +14,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_tensorboard
+from ..utils import check_model_output_exists, check_tensorboard

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -82,7 +81,7 @@ class TestFAXentropyLlama:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py
index d9d715103..69516810f 100644
--- a/tests/e2e/patched/test_falcon_samplepack.py
+++ b/tests/e2e/patched/test_falcon_samplepack.py
@@ -5,7 +5,6 @@ E2E tests for falcon
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestFalconPatched(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -109,4 +108,4 @@ class TestFalconPatched(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py
index 36b7442d9..23a0adfc0 100644
--- a/tests/e2e/patched/test_fused_llama.py
+++ b/tests/e2e/patched/test_fused_llama.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 import pytest
 from transformers.utils import is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -73,4 +72,4 @@ class TestFusedLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py
index 0f2539daf..d0fdd918a 100644
--- a/tests/e2e/patched/test_llama_s2_attention.py
+++ b/tests/e2e/patched/test_llama_s2_attention.py
@@ -5,7 +5,6 @@ E2E tests for llama w/ S2 attn
 import logging
 import os
 import unittest
-from pathlib import Path

 import pytest

@@ -15,7 +14,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -71,7 +70,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_fft_s2_attn(self, temp_dir):
@@ -111,4 +110,4 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py
index be2f133fb..634e544d2 100644
--- a/tests/e2e/patched/test_lora_llama_multipack.py
+++ b/tests/e2e/patched/test_lora_llama_multipack.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 import pytest
 from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -76,7 +75,7 @@ class TestLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
     @with_temp_dir
@@ -126,4 +125,4 @@ class TestLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py
index 6685fb9d5..e93863e09 100644
--- a/tests/e2e/patched/test_mistral_samplepack.py
+++ b/tests/e2e/patched/test_mistral_samplepack.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestMistral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft_packing(self, temp_dir):
@@ -110,4 +109,4 @@ class TestMistral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py
index 684baaaff..f87c34fd1 100644
--- a/tests/e2e/patched/test_mixtral_samplepack.py
+++ b/tests/e2e/patched/test_mixtral_samplepack.py
@@ -5,7 +5,6 @@ E2E tests for mixtral
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -66,7 +65,7 @@ class TestMixtral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -108,4 +107,4 @@ class TestMixtral(unittest.TestCase):
             "MixtralFlashAttention2"
             in model.model.layers[0].self_attn.__class__.__name__
         )
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py
index 7b5bf92df..852ac7bec 100644
--- a/tests/e2e/patched/test_phi_multipack.py
+++ b/tests/e2e/patched/test_phi_multipack.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestPhiMultipack(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_qlora_packed(self, temp_dir):
@@ -120,4 +119,4 @@ class TestPhiMultipack(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py
index 7d82ea8c3..5639d2eae 100644
--- a/tests/e2e/patched/test_resume.py
+++ b/tests/e2e/patched/test_resume.py
@@ -6,7 +6,6 @@ import logging
 import os
 import re
 import subprocess
-from pathlib import Path

 from transformers.utils import is_torch_bf16_gpu_available

@@ -16,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import most_recent_subdir
+from ..utils import check_model_output_exists, most_recent_subdir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -83,7 +82,7 @@ class TestResumeLlama:
         cli_args = TrainerCliArgs()

         train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

         tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
         cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"
diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py
index 0c0ee8610..492bc1c23 100644
--- a/tests/e2e/patched/test_unsloth_qlora.py
+++ b/tests/e2e/patched/test_unsloth_qlora.py
@@ -3,7 +3,6 @@ e2e tests for unsloth qlora
 """
 import logging
 import os
-from pathlib import Path

 import pytest

@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_tensorboard
+from ..utils import check_model_output_exists, check_tensorboard

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -77,7 +76,7 @@ class TestUnslothQLoRA:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -127,7 +126,7 @@ class TestUnslothQLoRA:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -182,7 +181,7 @@ class TestUnslothQLoRA:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index 4a705922f..f8109373a 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -15,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -68,7 +68,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

     @with_temp_dir
     def test_dpo_nll_lora(self, temp_dir):
@@ -113,7 +113,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

     @with_temp_dir
     def test_dpo_use_weighting(self, temp_dir):
@@ -158,7 +158,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

     @pytest.mark.skip("kto_pair no longer supported in trl")
     @with_temp_dir
@@ -203,7 +203,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

     @with_temp_dir
     def test_ipo_lora(self, temp_dir):
@@ -247,7 +247,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

     @with_temp_dir
     def test_orpo_lora(self, temp_dir):
@@ -294,7 +294,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

     @pytest.mark.skip(reason="Fix the implementation")
     @with_temp_dir
@@ -358,4 +358,4 @@ class TestDPOLlamaLora(unittest.TestCase):
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py
index 6e5ebd05f..222d620ae 100644
--- a/tests/e2e/test_embeddings_lr.py
+++ b/tests/e2e/test_embeddings_lr.py
@@ -5,7 +5,6 @@ E2E tests for llama pretrain
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_tensorboard, with_temp_dir
+from .utils import check_model_output_exists, check_tensorboard, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -62,7 +61,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
@@ -106,7 +105,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py
index c76699a7c..117de6635 100644
--- a/tests/e2e/test_falcon.py
+++ b/tests/e2e/test_falcon.py
@@ -5,7 +5,6 @@ E2E tests for falcon
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -71,7 +70,7 @@ class TestFalcon(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_lora_added_vocab(self, temp_dir):
@@ -124,7 +123,7 @@ class TestFalcon(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -163,4 +162,4 @@ class TestFalcon(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py
index 1ce9d60b9..4384bb61e 100644
--- a/tests/e2e/test_llama.py
+++ b/tests/e2e/test_llama.py
@@ -4,7 +4,8 @@ E2E tests for llama

 import logging
 import os
-from pathlib import Path
+
+from e2e.utils import check_model_output_exists

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -60,7 +61,7 @@ class TestLlama:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

     def test_fix_untrained_tokens(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -103,7 +104,7 @@ class TestLlama:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

     def test_batch_flattening(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -142,4 +143,4 @@ class TestLlama:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py
index 62fb63c47..d13b10659 100644
--- a/tests/e2e/test_llama_pretrain.py
+++ b/tests/e2e/test_llama_pretrain.py
@@ -5,7 +5,6 @@ E2E tests for llama pretrain
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -64,4 +63,4 @@ class TestPretrainLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py
index 1d583a326..250cf418c 100644
--- a/tests/e2e/test_llama_vision.py
+++ b/tests/e2e/test_llama_vision.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -68,7 +67,7 @@ class TestLlamaVision(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
@@ -113,4 +112,4 @@ class TestLlamaVision(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index d06be60b9..a7ead64a5 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -65,4 +64,4 @@ class TestLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
index 8755fa4d5..a1fc30862 100644
--- a/tests/e2e/test_mamba.py
+++ b/tests/e2e/test_mamba.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 import pytest

@@ -15,7 +14,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -65,4 +64,4 @@ class TestMamba(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py
index 57d85e51e..2e79fec8d 100644
--- a/tests/e2e/test_mistral.py
+++ b/tests/e2e/test_mistral.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from transformers.utils import is_torch_bf16_gpu_available

@@ -15,7 +14,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestMistral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -112,4 +111,4 @@ class TestMistral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py
index d4dad14ef..6792d05a6 100644
--- a/tests/e2e/test_mixtral.py
+++ b/tests/e2e/test_mixtral.py
@@ -5,7 +5,6 @@ E2E tests for mixtral
 import logging
 import os
 import unittest
-from pathlib import Path

 import torch
 from transformers.utils import is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -79,7 +78,7 @@ class TestMixtral(unittest.TestCase):
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
         )
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_qlora_wo_fa2(self, temp_dir):
@@ -133,7 +132,7 @@ class TestMixtral(unittest.TestCase):
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
         )
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_16bit_lora_w_fa2(self, temp_dir):
@@ -190,7 +189,7 @@ class TestMixtral(unittest.TestCase):
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
         )
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_16bit_lora_wo_fa2(self, temp_dir):
@@ -247,7 +246,7 @@ class TestMixtral(unittest.TestCase):
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
         )
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -287,4 +286,4 @@ class TestMixtral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index f69d0500f..f1bbaafd5 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -5,7 +5,6 @@ E2E tests for custom optimizers using Llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import require_torch_2_5_1, with_temp_dir
+from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -65,7 +64,7 @@ class TestCustomOptimizers(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     @require_torch_2_5_1
@@ -109,7 +108,7 @@ class TestCustomOptimizers(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_fft_schedule_free_adamw(self, temp_dir):
@@ -145,4 +144,4 @@ class TestCustomOptimizers(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py
index 4cc6bcdcc..7a08d0c6f 100644
--- a/tests/e2e/test_phi.py
+++ b/tests/e2e/test_phi.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -67,7 +66,7 @@ class TestPhi(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_phi_qlora(self, temp_dir):
@@ -116,4 +115,4 @@ class TestPhi(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_relora_llama.py b/tests/e2e/test_relora_llama.py
index 84582896d..fef6a3d30 100644
--- a/tests/e2e/test_relora_llama.py
+++ b/tests/e2e/test_relora_llama.py
@@ -13,7 +13,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_tensorboard, with_temp_dir
+from .utils import check_model_output_exists, check_tensorboard, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -78,10 +78,10 @@ class TestReLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
         assert (
-            Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
-        ).exists()
-        assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
+            Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
+        ).exists(), "Relora model checkpoint not found"

         check_tensorboard(
             temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
diff --git a/tests/e2e/test_reward_model_llama.py b/tests/e2e/test_reward_model_llama.py
index 27ac3e25f..c4cb705ea 100644
--- a/tests/e2e/test_reward_model_llama.py
+++ b/tests/e2e/test_reward_model_llama.py
@@ -5,7 +5,6 @@ E2E tests for reward model lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -71,4 +70,4 @@ class TestRewardModelLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index 1e05c32c4..759d59659 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -14,6 +14,8 @@ import torch
 from packaging import version
 from tbparse import SummaryReader

+from axolotl.utils.dict import DictDefault
+

 def with_temp_dir(test_func):
     @wraps(test_func)
@@ -93,3 +95,27 @@ def check_tensorboard(
     df = reader.scalars  # pylint: disable=invalid-name
     df = df[(df.tag == tag)]  # pylint: disable=invalid-name
     assert df.value.values[-1] < lt_val, assertion_err
+
+
+def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
+    """
+    helper function to check if a model output file exists after training
+
+    checks based on adapter or not and if safetensors saves are enabled or not
+    """
+
+    if cfg.save_safetensors:
+        if not cfg.adapter:
+            assert (Path(temp_dir) / "model.safetensors").exists()
+        else:
+            assert (Path(temp_dir) / "adapter_model.safetensors").exists()
+    else:
+        # check for both, b/c in trl, it often defaults to saving safetensors
+        if not cfg.adapter:
+            assert (Path(temp_dir) / "pytorch_model.bin").exists() or (
+                Path(temp_dir) / "model.safetensors"
+            ).exists()
+        else:
+            assert (Path(temp_dir) / "adapter_model.bin").exists() or (
+                Path(temp_dir) / "adapter_model.safetensors"
+            ).exists()
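The helper collapses four possible save layouts into one assertion. A minimal sketch of
the cases it accepts, using illustrative configs and an illustrative output directory
(none of the values below are taken from this patch):

```python
from axolotl.utils.dict import DictDefault

from e2e.utils import check_model_output_exists

# full finetune with safetensors saves -> requires <dir>/model.safetensors
check_model_output_exists("/tmp/out", DictDefault({"save_safetensors": True}))

# LoRA without safetensors -> accepts adapter_model.bin OR
# adapter_model.safetensors, since trl trainers often save safetensors anyway
check_model_output_exists(
    "/tmp/out", DictDefault({"adapter": "lora", "save_safetensors": False})
)
```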
From bc1c9c20e3cadb2a60163fa73370161a806124f0 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 13 Jan 2025 10:44:11 -0500
Subject: [PATCH 02/16] assume empty lora dropout means 0.0 and add tests
 (#2243)

* assume empty lora dropout means 0.0 and add tests

* remove un-necessary arg

* refactor based on pr feedback

* chore: lint
---
 .../config/models/input/v0_4_1/__init__.py |  7 ++
 tests/test_lora.py                         | 69 +++++++++++++++++++
 2 files changed, 76 insertions(+)
 create mode 100644 tests/test_lora.py

diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index bb88a0baa..19ce7b18c 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -367,6 +367,13 @@ class LoraConfig(BaseModel):
             loraplus_lr_embedding = float(loraplus_lr_embedding)
         return loraplus_lr_embedding

+    @model_validator(mode="before")
+    @classmethod
+    def validate_lora_dropout(cls, data):
+        if data.get("adapter") is not None and data.get("lora_dropout") is None:
+            data["lora_dropout"] = 0.0
+        return data
+

 class ReLoRAConfig(BaseModel):
     """ReLoRA configuration subset"""
diff --git a/tests/test_lora.py b/tests/test_lora.py
new file mode 100644
index 000000000..b917ff3f9
--- /dev/null
+++ b/tests/test_lora.py
@@ -0,0 +1,69 @@
+"""
+tests for loading loras
+"""
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
+
+# pylint: disable=duplicate-code
+minimal_config = DictDefault(
+    {
+        "base_model": "HuggingFaceTB/SmolLM2-135M",
+        "learning_rate": 0.000001,
+        "datasets": [
+            {
+                "path": "mhenrichsen/alpaca_2k_test",
+                "type": "alpaca",
+            }
+        ],
+        "micro_batch_size": 1,
+        "gradient_accumulation_steps": 1,
+    }
+)
+
+
+class TestLoRALoad:
+    """
+    Test class for loading LoRA weights
+    """
+
+    def test_load_lora_weights(self):
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.0,
+                "lora_target_linear": True,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "sequence_len": 1024,
+            }
+            | minimal_config
+        )
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        tokenizer = load_tokenizer(cfg)
+        load_model(cfg, tokenizer)
+
+    def test_load_lora_weights_empty_dropout(self):
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": None,
+                "lora_target_linear": True,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "sequence_len": 1024,
+            }
+            | minimal_config
+        )
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        assert cfg.lora_dropout == 0.0
+        tokenizer = load_tokenizer(cfg)
+        load_model(cfg, tokenizer)
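Outside the pydantic pipeline, the rule the new validator enforces reduces to a small
"before" hook on the raw config dict. A standalone sketch using plain dicts, carrying
over nothing from the patch except the two field names:

```python
def validate_lora_dropout(data: dict) -> dict:
    # mirror of the model_validator(mode="before") semantics above: an adapter
    # with no explicit lora_dropout gets 0.0 instead of failing validation
    if data.get("adapter") is not None and data.get("lora_dropout") is None:
        data["lora_dropout"] = 0.0
    return data


assert validate_lora_dropout({"adapter": "lora"})["lora_dropout"] == 0.0
assert "lora_dropout" not in validate_lora_dropout({})  # no adapter: untouched
```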
From f89e9621191f4460afe4425b3785dd8ca0482a9c Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 13 Jan 2025 10:44:45 -0500
Subject: [PATCH 03/16] skip over rows in pretraining dataset (#2223)

* skip over rows in pretraining dataset

* update docs
---
 docs/dataset-formats/pretraining.qmd               |  9 ++++++++-
 .../utils/config/models/input/v0_4_1/__init__.py   |  1 +
 src/axolotl/utils/data/sft.py                      | 10 +++++++---
 3 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/docs/dataset-formats/pretraining.qmd b/docs/dataset-formats/pretraining.qmd
index bb591328e..600fb63e0 100644
--- a/docs/dataset-formats/pretraining.qmd
+++ b/docs/dataset-formats/pretraining.qmd
@@ -19,7 +19,14 @@ For pretraining, there is no prompt template or roles. The only required field

 Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:

 ```{.yaml filename="config.yaml"}
-pretraining_dataset: # hf path only
+pretraining_dataset:
+  - name:
+    path:
+    split:
+    text_column: # column in dataset with the data, usually `text`
+    type: pretrain
+    trust_remote_code:
+    skip: # number of rows of data to skip over from the beginning
 ...
 ```
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 19ce7b18c..4f368994a 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -129,6 +129,7 @@ class PretrainingDataset(BaseModel):
     type: Optional[str] = "pretrain"
     trust_remote_code: Optional[bool] = False
     data_files: Optional[str] = None
+    skip: Optional[int] = None


 class UserDefinedPrompterType(BaseModel):
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index cfc40406e..aff047675 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -89,11 +89,13 @@ def prepare_dataset(cfg, tokenizer, processor=None):
         split = "train"
         name = None
         data_files = None
+        skip = 0
         if isinstance(cfg.pretraining_dataset, list) and isinstance(
             cfg.pretraining_dataset[0], dict
         ):
             path = cfg.pretraining_dataset[0]["path"]
             name = cfg.pretraining_dataset[0]["name"]
+            skip = cfg.pretraining_dataset[0]["skip"]
             if "split" in cfg.pretraining_dataset[0]:
                 split = cfg.pretraining_dataset[0]["split"]

@@ -107,10 +109,12 @@ def prepare_dataset(cfg, tokenizer, processor=None):
             cfg.pretraining_dataset[0]["type"] or "pretrain",
         )

+        iter_ds = load_dataset(path, streaming=True, split=split, name=name, data_files=data_files)
+        if skip:
+            LOG.info(f"Skipping {skip} samples from the dataset")
+            iter_ds = iter_ds.skip(skip)
         train_dataset = wrap_pretraining_dataset(
-            load_dataset(
-                path, streaming=True, split=split, name=name, data_files=data_files
-            ),
+            iter_ds,
             tokenizer,
             cfg,
             ds_wrapper_partial,
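The skip is applied on the streaming iterator before any tokenization or packing. A
standalone sketch of the same pattern with the Hugging Face datasets API (the dataset
name is illustrative, not from this patch):

```python
from datasets import load_dataset

# stream instead of materializing the corpus in memory
iter_ds = load_dataset("allenai/c4", "en", split="train", streaming=True)

# fast-forward past already-consumed rows, e.g. when resuming a pretraining run;
# this is the same IterableDataset.skip() call used in prepare_dataset above
iter_ds = iter_ds.skip(1000)

print(next(iter(iter_ds))["text"][:80])
```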
From 1ed4de73b615fe9d943906e0bf5429b1466cb77e Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Mon, 13 Jan 2025 12:55:29 -0500
Subject: [PATCH 04/16] CLI cleanup and documentation (#2244)

* CLI init refactor

* fix

* cleanup and (partial) docs

* Adding documentation and continuing cleanup (in progress)

* remove finetune.py script

* continued cleanup and documentation

* pytest fixes

* review comments

* fix

* Fix

* typing fixes

* make sure the batch dataset patcher for multipack is always loaded when handling datasets

* review comments

* fix

---------

Co-authored-by: Dan Saunders
Co-authored-by: Wing Lian
---
 scripts/finetune.py                            |  52 --
 src/axolotl/cli/__init__.py                    | 565 +-----------------
 src/axolotl/cli/args.py                        |  43 ++
 src/axolotl/cli/art.py                         |  23 +
 src/axolotl/cli/checks.py                      |  50 ++
 src/axolotl/cli/config.py                      | 217 +++++++
 src/axolotl/cli/evaluate.py                    |  46 +-
 src/axolotl/cli/inference.py                   | 259 +++++++-
 src/axolotl/cli/main.py                        | 180 +++---
 src/axolotl/cli/merge_lora.py                  |  62 +-
 src/axolotl/cli/merge_sharded_fsdp_weights.py  |  49 +-
 src/axolotl/cli/preprocess.py                  |  73 ++-
 src/axolotl/cli/shard.py                       |  45 --
 src/axolotl/cli/train.py                       |  65 +-
 src/axolotl/cli/utils.py                       | 156 ++++-
 src/axolotl/common/cli.py                      |  69 ---
 src/axolotl/common/datasets.py                 | 140 +++++
 src/axolotl/evaluate.py                        |  14 +-
 src/axolotl/train.py                           |  29 +-
 src/axolotl/utils/data/sft.py                  |   4 +-
 tests/cli/conftest.py                          |   1 +
 tests/cli/test_cli_fetch.py                    |   1 +
 tests/cli/test_cli_inference.py                |   1 +
 tests/cli/test_cli_interface.py                |   1 +
 tests/cli/test_cli_merge_lora.py               |   1 +
 .../test_cli_merge_sharded_fsdp_weights.py     |  44 +-
 tests/cli/test_cli_preprocess.py               |   1 +
 tests/cli/test_cli_shard.py                    |  76 ---
 tests/cli/test_cli_version.py                  |   1 +
 tests/cli/test_utils.py                        |   1 +
 .../integrations/test_cut_cross_entropy.py     |  12 +-
 tests/e2e/integrations/test_liger.py           |   8 +-
 tests/e2e/patched/test_4d_multipack_llama.py   |   8 +-
 tests/e2e/patched/test_cli_integrations.py     |   2 +-
 tests/e2e/patched/test_fa_xentropy.py          |   6 +-
 tests/e2e/patched/test_falcon_samplepack.py    |   8 +-
 tests/e2e/patched/test_fused_llama.py          |   6 +-
 tests/e2e/patched/test_llama_s2_attention.py   |   8 +-
 .../e2e/patched/test_lora_llama_multipack.py   |   8 +-
 tests/e2e/patched/test_mistral_samplepack.py   |   8 +-
 tests/e2e/patched/test_mixtral_samplepack.py   |   8 +-
 tests/e2e/patched/test_model_patches.py        |   7 +-
 tests/e2e/patched/test_phi_multipack.py        |   8 +-
 tests/e2e/patched/test_resume.py               |   8 +-
 tests/e2e/patched/test_unsloth_qlora.py        |  10 +-
 tests/e2e/test_dpo.py                          |  32 +-
 tests/e2e/test_embeddings_lr.py                |   8 +-
 tests/e2e/test_falcon.py                       |  10 +-
 tests/e2e/test_llama.py                        |  10 +-
 tests/e2e/test_llama_pretrain.py               |   6 +-
 tests/e2e/test_llama_vision.py                 |   8 +-
 tests/e2e/test_lora_llama.py                   |   6 +-
 tests/e2e/test_mamba.py                        |   6 +-
 tests/e2e/test_mistral.py                      |   8 +-
 tests/e2e/test_mixtral.py                      |  14 +-
 tests/e2e/test_optimizers.py                   |  10 +-
 tests/e2e/test_packing_loss.py                 |   6 +-
 tests/e2e/test_phi.py                          |   8 +-
 tests/e2e/test_relora_llama.py                 |   6 +-
 tests/e2e/test_reward_model_llama.py           |   6 +-
 60 files changed, 1269 insertions(+), 1259 deletions(-)
 delete mode 100644 scripts/finetune.py
 create mode 100644 src/axolotl/cli/args.py
 create mode 100644 src/axolotl/cli/art.py
 create mode 100644 src/axolotl/cli/checks.py
 create mode 100644 src/axolotl/cli/config.py
 delete mode 100644 src/axolotl/cli/shard.py
 delete mode 100644 src/axolotl/common/cli.py
 create mode 100644 src/axolotl/common/datasets.py
 delete mode 100644 tests/cli/test_cli_shard.py
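The module names in this file list suggest a click-based entrypoint in cli/main.py
with one subcommand per module. A rough sketch of exercising it programmatically,
assuming a click group exported as `cli` (the group name and subcommand are
assumptions inferred from the file list, not confirmed by this diff):

```python
from click.testing import CliRunner

from axolotl.cli.main import cli  # assumed export; see src/axolotl/cli/main.py

runner = CliRunner()
# invoke the (assumed) preprocess subcommand against an example config
result = runner.invoke(cli, ["preprocess", "examples/config.yml"])
print(result.exit_code, result.output[:200])
```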
diff --git a/scripts/finetune.py b/scripts/finetune.py
deleted file mode 100644
index d5bbcaf8f..000000000
--- a/scripts/finetune.py
+++ /dev/null
@@ -1,52 +0,0 @@
-"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
-import logging
-from pathlib import Path
-
-import fire
-import transformers
-
-from axolotl.cli import (
-    check_accelerate_default_config,
-    check_user_token,
-    do_inference,
-    do_merge_lora,
-    load_cfg,
-    load_datasets,
-    print_axolotl_text_art,
-)
-from axolotl.cli.shard import shard
-from axolotl.common.cli import TrainerCliArgs
-from axolotl.train import train
-
-LOG = logging.getLogger("axolotl.scripts.finetune")
-
-
-def do_cli(config: Path = Path("examples/"), **kwargs):
-    print_axolotl_text_art()
-    LOG.warning(
-        str(
-            PendingDeprecationWarning(
-                "scripts/finetune.py will be replaced with calling axolotl.cli.train"
-            )
-        )
-    )
-    parsed_cfg = load_cfg(config, **kwargs)
-    check_accelerate_default_config()
-    check_user_token()
-    parser = transformers.HfArgumentParser((TrainerCliArgs))
-    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
-        return_remaining_strings=True
-    )
-    if parsed_cli_args.inference:
-        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    elif parsed_cli_args.merge_lora:
-        do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    elif parsed_cli_args.shard:
-        shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    else:
-        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-        train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
-
-
-if __name__ == "__main__":
-    fire.Fire(do_cli)
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index d07b10ce3..b20e4f085 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -1,568 +1,5 @@
-"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
+"""Axolotl CLI module initialization."""

-import importlib
-import json
-import logging
-import math
 import os
-import random
-import sys
-import tempfile
-from pathlib import Path
-from threading import Thread
-from typing import Any, Dict, List, Optional, Union
-from urllib.parse import urlparse
-
-import requests
-import torch
-import yaml
-
-# add src to the pythonpath so we don't need to pip install this
-from accelerate.commands.config import config_args
-from art import text2art
-from huggingface_hub import HfApi
-from huggingface_hub.utils import LocalTokenNotFoundError
-from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
-from transformers.utils import is_torch_bf16_gpu_available
-from transformers.utils.import_utils import _is_package_available
-
-from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
-from axolotl.logging_config import configure_logging
-from axolotl.train import TrainDatasetMeta
-from axolotl.utils.chat_templates import (
-    get_chat_template,
-    get_chat_template_from_config,
-)
-from axolotl.utils.comet_ import setup_comet_env_vars
-from axolotl.utils.config import (
-    normalize_cfg_datasets,
-    normalize_config,
-    prepare_plugins,
-    validate_config,
-)
-from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_main_process
-from axolotl.utils.mlflow_ import setup_mlflow_env_vars
-from axolotl.utils.models import load_processor, load_tokenizer
-from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
-from axolotl.utils.wandb_ import setup_wandb_env_vars
-
-project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-src_dir = os.path.join(project_root, "src")
-sys.path.insert(0, src_dir)
-
-configure_logging()
-LOG = logging.getLogger("axolotl.scripts")
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
-AXOLOTL_LOGO = """
-     #@@ #@@      @@# @@#
-    @@  @@ @@  @@ =@@#   @@  #@  =@@#.
-    @@ #@@@@@@@@@ @@ #@#@=   @@  #@  .=@@
-   #@@@@@@@@@@@@@@@@@ =@# @#  ##= ## =####=+   @@ =#####+  =#@@###. @@
-   @@@@@@@@@@/ +@@/ +@@  #@ =@=  #@= @@ =@#+ +#@#  @@ =@#+ +#@# #@. @@
-   @@@@@@@@@@  ##@@ ##@@  =@# @#  =@#  @# @@   @@  @@ @@    #@ #@  @@
-   @@@@@@@@@@@@@@@@@@@@  #@=+++#@=  =@@#  @@   @@  @@ @@    #@ #@  @@
-   =@#=====@@  =@# @#  @@   @@  @@ @@    #@ #@  @@
-   @@@@@@@@@@@@@@@@ @@@@  #@ #@=  #@= +@@  #@# =@#  @@. =@# =@# #@. @@
-   =@# @#  #@= #@  =#@@@@#=  +#@@= +#@@@@#= .##@@+ @@
-   @@@@ @@@@@@@@@@@@@@@@
-"""
-
-
-def print_legacy_axolotl_text_art(suffix=None):
-    font = "nancyj"
-    ascii_text = "  axolotl"
-    if suffix:
-        ascii_text += f"  x  {suffix}"
-    ascii_art = text2art(ascii_text, font=font)
-
-    if is_main_process():
-        print(ascii_art)
-
-    print_dep_versions()
-
-
-def print_axolotl_text_art(
-    **kwargs,  # pylint: disable=unused-argument
-):
-    if is_main_process():
-        print(AXOLOTL_LOGO)
-
-
-def print_dep_versions():
-    packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
-    max_len = max(len(pkg) for pkg in packages)
-    if is_main_process():
-        print("*" * 40)
-        print("**** Axolotl Dependency Versions *****")
-        for pkg in packages:
-            pkg_version = _is_package_available(pkg, return_version=True)
-            print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
-        print("*" * 40)
-
-
-def check_remote_config(config: Union[str, Path]):
-    # Check if the config is a valid HTTPS URL to a .yml or .yaml file
-    if not (isinstance(config, str) and config.startswith("https://")):
-        return config  # Return the original value if it's not a valid URL
-
-    filename = os.path.basename(urlparse(config).path)
-    temp_dir = tempfile.mkdtemp()
-
-    try:
-        response = requests.get(config, timeout=30)
-        response.raise_for_status()  # Check for HTTP errors
-
-        content = response.content
-        try:
-            # Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
-            json.loads(content)
-            # Log a warning but do not raise an error; JSON is technically valid YAML -
-            # this can happen when you forget to point to a raw github link
-            LOG.warning(
-                f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
-            )
-        except json.JSONDecodeError:
-            # If it's not valid JSON, verify it's valid YAML
-            try:
-                yaml.safe_load(content)
-            except yaml.YAMLError as err:
-                raise ValueError(
-                    f"Failed to parse the content at {config} as YAML: {err}"
-                ) from err
-
-        # Write the content to a file if it's valid YAML (or JSON treated as YAML)
-        output_path = Path(temp_dir) / filename
-        with open(output_path, "wb") as file:
-            file.write(content)
-        LOG.info(
-            f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
-        )
-        return output_path
-
-    except requests.RequestException as err:
-        # This catches all requests-related exceptions including HTTPError
-        raise RuntimeError(f"Failed to download {config}: {err}") from err
-    except Exception as err:
-        # Catch-all for any other exceptions
-        raise err
-
-
-def get_multi_line_input() -> Optional[str]:
-    print("Give me an instruction (Ctrl + D to submit): ")
-    instruction = ""
-    for line in sys.stdin:
-        instruction += line  # pylint: disable=consider-using-join
-    # instruction = pathlib.Path("/proc/self/fd/0").read_text()
-    return instruction
-
-
-def do_merge_lora(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-):
-    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-    safe_serialization = cfg.save_safetensors is True
-
-    LOG.info("running merge of LoRA with base model")
-    model = model.merge_and_unload(progressbar=True)
-    try:
-        model.to(dtype=cfg.torch_dtype)
-    except RuntimeError:
-        pass
-    model.generation_config.do_sample = True
-
-    if cfg.local_rank == 0:
-        LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
-        model.save_pretrained(
-            str(Path(cfg.output_dir) / "merged"),
-            safe_serialization=safe_serialization,
-            progressbar=True,
-        )
-        tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
-
-
-def do_inference(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-):
-    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-    prompter = cli_args.prompter
-
-    prompter_module = None
-    chat_template_str = None
-    if prompter:
-        prompter_module = getattr(
-            importlib.import_module("axolotl.prompters"), prompter
-        )
-    elif cfg.chat_template:
-        chat_template_str = get_chat_template(cfg.chat_template)
-    elif cfg.datasets[0].type == "chat_template":
-        chat_template_str = get_chat_template_from_config(
-            cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
-        )
-
-    model = model.to(cfg.device, dtype=cfg.torch_dtype)
-
-    while True:
-        print("=" * 80)
-        # support for multiline inputs
-        instruction = get_multi_line_input()
-        if not instruction:
-            return
-
-        if prompter_module:
-            prompt: str = next(
-                prompter_module().build_prompt(instruction=instruction.strip("\n"))
-            )
-        else:
-            prompt = instruction.strip()
-
-        if chat_template_str:
-            batch = tokenizer.apply_chat_template(
-                [
-                    {
-                        "role": "user",
-                        "content": prompt,
-                    }
-                ],
-                return_tensors="pt",
-                add_special_tokens=True,
-                add_generation_prompt=True,
-                chat_template=chat_template_str,
-                tokenize=True,
-                return_dict=True,
-            )
-        else:
-            batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
-
-        print("=" * 40)
-        model.eval()
-        with torch.no_grad():
-            generation_config = GenerationConfig(
-                repetition_penalty=1.1,
-                max_new_tokens=1024,
-                temperature=0.9,
-                top_p=0.95,
-                top_k=40,
-                bos_token_id=tokenizer.bos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                pad_token_id=tokenizer.pad_token_id,
-                do_sample=True,
-                use_cache=True,
-                return_dict_in_generate=True,
-                output_attentions=False,
-                output_hidden_states=False,
-                output_scores=False,
-            )
-            streamer = TextStreamer(tokenizer)
-            generated = model.generate(
-                inputs=batch["input_ids"].to(cfg.device),
-                generation_config=generation_config,
-                streamer=streamer,
-            )
-        print("=" * 40)
-        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
-
-
-def do_inference_gradio(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-):
-    import gradio as gr
-
-    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-    prompter = cli_args.prompter
-
-    prompter_module = None
-    chat_template_str = None
-    if prompter:
-        prompter_module = getattr(
-            importlib.import_module("axolotl.prompters"), prompter
-        )
-    elif cfg.chat_template:
-        chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
-
-    model = model.to(cfg.device, dtype=cfg.torch_dtype)
-
-    def generate(instruction):
-        if not instruction:
-            return
-        if prompter_module:
-            # pylint: disable=stop-iteration-return
-            prompt: str = next(
-                prompter_module().build_prompt(instruction=instruction.strip("\n"))
-            )
-        else:
-            prompt = instruction.strip()
-
-        if chat_template_str:
-            batch = tokenizer.apply_chat_template(
-                [
-                    {
-                        "role": "user",
-                        "content": prompt,
-                    }
-                ],
-                return_tensors="pt",
-                add_special_tokens=True,
-                add_generation_prompt=True,
-                chat_template=chat_template_str,
-                tokenize=True,
-                return_dict=True,
-            )
-        else:
-            batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
-
-        model.eval()
-        with torch.no_grad():
-            generation_config = GenerationConfig(
-                repetition_penalty=1.1,
-                max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
-                temperature=cfg.get("gradio_temperature", 0.9),
-                top_p=0.95,
-                top_k=40,
-                bos_token_id=tokenizer.bos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                pad_token_id=tokenizer.pad_token_id,
-                do_sample=True,
-                use_cache=True,
-                return_dict_in_generate=True,
-                output_attentions=False,
-                output_hidden_states=False,
-                output_scores=False,
-            )
-            streamer = TextIteratorStreamer(tokenizer)
-            generation_kwargs = {
-                "inputs": batch["input_ids"].to(cfg.device),
-                "attention_mask": batch["attention_mask"].to(cfg.device),
-                "generation_config": generation_config,
-                "streamer": streamer,
-            }
-
-            thread = Thread(target=model.generate, kwargs=generation_kwargs)
-            thread.start()
-
-            all_text = ""
-
-            for new_text in streamer:
-                all_text += new_text
-                yield all_text
-
-    demo = gr.Interface(
-        fn=generate,
-        inputs="textbox",
-        outputs="text",
-        title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
-    )
-
-    demo.queue().launch(
-        show_api=False,
-        share=cfg.get("gradio_share", True),
-        server_name=cfg.get("gradio_server_name", "127.0.0.1"),
-        server_port=cfg.get("gradio_server_port", None),
-    )
-
-
-def choose_config(path: Path):
-    yaml_files = list(path.glob("*.yml"))
-
-    if not yaml_files:
-        raise ValueError(
-            "No YAML config files found in the specified directory. Are you using a .yml extension?"
-        )
-
-    if len(yaml_files) == 1:
-        print(f"Using default YAML file '{yaml_files[0]}'")
-        return str(yaml_files[0])
-
-    print("Choose a YAML file:")
-    for idx, file in enumerate(yaml_files):
-        print(f"{idx + 1}. {file}")
-
-    chosen_file = None
-    while chosen_file is None:
-        try:
-            choice = int(input("Enter the number of your choice: "))
-            if 1 <= choice <= len(yaml_files):
-                chosen_file = str(yaml_files[choice - 1])
-            else:
-                print("Invalid choice. Please choose a number from the list.")
-        except ValueError:
-            print("Invalid input. Please enter a number.")
-
-    return chosen_file
-
-
-def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
-    return not any(el in list2 for el in list1)
-
-
-def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
-    config = check_remote_config(config)
-    if Path(config).is_dir():
-        config = choose_config(Path(config))
-
-    # load the config from the yaml file
-    with open(config, encoding="utf-8") as file:
-        cfg: DictDefault = DictDefault(yaml.safe_load(file))
-    # if there are any options passed in the cli, if it is something that seems valid from the yaml,
-    # then overwrite the value
-    cfg_keys = cfg.keys()
-    for k, _ in kwargs.items():
-        # if not strict, allow writing to cfg even if it's not in the yml already
-        if k in cfg_keys or not cfg.strict:
-            # handle booleans
-            if isinstance(cfg[k], bool):
-                cfg[k] = bool(kwargs[k])
-            else:
-                cfg[k] = kwargs[k]
-
-    cfg.axolotl_config_path = config
-
-    try:
-        device_props = torch.cuda.get_device_properties("cuda")
-        gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
-    except:  # pylint: disable=bare-except  # noqa: E722
-        gpu_version = None
-
-    prepare_plugins(cfg)
-
-    cfg = validate_config(
-        cfg,
-        capabilities={
-            "bf16": is_torch_bf16_gpu_available(),
-            "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
-            "compute_capability": gpu_version,
-        },
-        env_capabilities={
-            "torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
-        },
-    )
-
-    prepare_optim_env(cfg)
-
-    prepare_opinionated_env(cfg)
-
-    normalize_config(cfg)
-
-    normalize_cfg_datasets(cfg)
-
-    setup_wandb_env_vars(cfg)
-
-    setup_mlflow_env_vars(cfg)
-
-    setup_comet_env_vars(cfg)
-
-    return cfg
-
-
-def load_datasets(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-) -> TrainDatasetMeta:
-    tokenizer = 
load_tokenizer(cfg) - processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None - - train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( - cfg, - tokenizer, - processor=processor, - ) - - if ( - cli_args.debug - or cfg.debug - or cli_args.debug_text_only - or int(cli_args.debug_num_examples) > 0 - ): - LOG.info("check_dataset_labels...") - check_dataset_labels( - train_dataset.select( - [ - random.randrange(0, len(train_dataset) - 1) # nosec - for _ in range(cli_args.debug_num_examples) - ] - ), - tokenizer, - num_examples=cli_args.debug_num_examples, - text_only=cli_args.debug_text_only, - ) - - LOG.info("printing prompters...") - for prompter in prompters: - LOG.info(prompter) - - return TrainDatasetMeta( - train_dataset=train_dataset, - eval_dataset=eval_dataset, - total_num_steps=total_num_steps, - ) - - -def load_rl_datasets( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, # pylint: disable=unused-argument -) -> TrainDatasetMeta: - train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) - total_num_steps = int( - math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) - ) - - if cli_args.debug or cfg.debug: - LOG.info("check_dataset_labels...") - - tokenizer = load_tokenizer(cfg) - check_dataset_labels( - train_dataset.select( - [ - random.randrange(0, len(train_dataset) - 1) # nosec - for _ in range(cli_args.debug_num_examples) - ] - ), - tokenizer, - num_examples=cli_args.debug_num_examples, - text_only=cli_args.debug_text_only, - rl_mode=True, - ) - - return TrainDatasetMeta( - train_dataset=train_dataset, - eval_dataset=eval_dataset, - total_num_steps=total_num_steps, - ) - - -def check_accelerate_default_config(): - if Path(config_args.default_yaml_config_file).exists(): - LOG.warning( - f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors" - ) - - -def check_user_token(): - # Skip check if HF_HUB_OFFLINE is set to True - if os.getenv("HF_HUB_OFFLINE") == "1": - LOG.info( - "Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used." - ) - return True - - # Verify if token is valid - api = HfApi() - try: - user_info = api.whoami() - return bool(user_info) - except LocalTokenNotFoundError: - LOG.warning( - "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets." 
- ) - return False diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py new file mode 100644 index 000000000..0618e07f1 --- /dev/null +++ b/src/axolotl/cli/args.py @@ -0,0 +1,43 @@ +"""Module for axolotl CLI command arguments.""" + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class PreprocessCliArgs: + """Dataclass with CLI arguments for `axolotl preprocess` command.""" + + debug: bool = field(default=False) + debug_text_only: bool = field(default=False) + debug_num_examples: int = field(default=1) + prompter: Optional[str] = field(default=None) + download: Optional[bool] = field(default=True) + + +@dataclass +class TrainerCliArgs: + """Dataclass with CLI arguments for `axolotl train` command.""" + + debug: bool = field(default=False) + debug_text_only: bool = field(default=False) + debug_num_examples: int = field(default=0) + merge_lora: bool = field(default=False) + prompter: Optional[str] = field(default=None) + shard: bool = field(default=False) + + +@dataclass +class EvaluateCliArgs: + """Dataclass with CLI arguments for `axolotl evaluate` command.""" + + debug: bool = field(default=False) + debug_text_only: bool = field(default=False) + debug_num_examples: int = field(default=0) + + +@dataclass +class InferenceCliArgs: + """Dataclass with CLI arguments for `axolotl inference` command.""" + + prompter: Optional[str] = field(default=None) diff --git a/src/axolotl/cli/art.py b/src/axolotl/cli/art.py new file mode 100644 index 000000000..6ed22a52d --- /dev/null +++ b/src/axolotl/cli/art.py @@ -0,0 +1,23 @@ +"""Axolotl ASCII logo utils.""" + +from axolotl.utils.distributed import is_main_process + +AXOLOTL_LOGO = """ + #@@ #@@ @@# @@# + @@ @@ @@ @@ =@@# @@ #@ =@@#. + @@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@ + #@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@ + @@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@ + @@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@ + @@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@ + =@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@ + @@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@ + =@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@ + @@@@ @@@@@@@@@@@@@@@@ +""" + + +def print_axolotl_text_art(): + """Prints axolotl ASCII art.""" + if is_main_process(): + print(AXOLOTL_LOGO) diff --git a/src/axolotl/cli/checks.py b/src/axolotl/cli/checks.py new file mode 100644 index 000000000..cc3ed0d9f --- /dev/null +++ b/src/axolotl/cli/checks.py @@ -0,0 +1,50 @@ +"""Various checks for Axolotl CLI.""" + +import logging +import os +from pathlib import Path + +from accelerate.commands.config import config_args +from huggingface_hub import HfApi +from huggingface_hub.utils import LocalTokenNotFoundError + +from axolotl.logging_config import configure_logging + +configure_logging() +LOG = logging.getLogger(__name__) + + +def check_accelerate_default_config() -> None: + """Logs at warning level if no accelerate config file is found.""" + if Path(config_args.default_yaml_config_file).exists(): + LOG.warning( + f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors" + ) + + +def check_user_token() -> bool: + """Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1. + + Returns: + Boolean indicating successful check (i.e., HF_HUB_OFFLINE=1 or HF user info is retrieved). + + Raises: + LocalTokenNotFoundError: If HF user info can't be retrieved. 
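A note on the `check_user_token` helper above: although the docstring lists a `Raises` entry, the `LocalTokenNotFoundError` is caught inside the function and surfaces as a `False` return value, so callers can simply branch on the result. A hedged sketch of a caller (the exit message is illustrative, not part of the patch):

```python
import sys

from axolotl.cli.checks import check_user_token

# Fail fast when gated models are required but no HF token is configured;
# export HF_HUB_OFFLINE=1 beforehand to skip the check for local-only runs.
if not check_user_token():
    sys.exit("No valid HuggingFace token found; run `huggingface-cli login` first.")
```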
+ """ + # Skip check if HF_HUB_OFFLINE is set to True + if os.getenv("HF_HUB_OFFLINE") == "1": + LOG.info( + "Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used." + ) + return True + + # Verify if token is valid + api = HfApi() + try: + user_info = api.whoami() + return bool(user_info) + except LocalTokenNotFoundError: + LOG.warning( + "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets." + ) + return False diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py new file mode 100644 index 000000000..166a67670 --- /dev/null +++ b/src/axolotl/cli/config.py @@ -0,0 +1,217 @@ +"""Configuration loading and processing.""" + +import json +import logging +import os +import tempfile +from pathlib import Path +from typing import Union +from urllib.parse import urlparse + +import requests +import torch +import yaml +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.integrations.base import PluginManager +from axolotl.utils.comet_ import setup_comet_env_vars +from axolotl.utils.config import ( + normalize_cfg_datasets, + normalize_config, + validate_config, +) +from axolotl.utils.dict import DictDefault +from axolotl.utils.mlflow_ import setup_mlflow_env_vars +from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env +from axolotl.utils.wandb_ import setup_wandb_env_vars + +LOG = logging.getLogger(__name__) + + +def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: + """ + First, determines if the passed config is a valid HTTPS URL. Then, attempts to query + for it and parse its content, first as JSON, then as YAML (YAML is preferred). + Finally, the parsed content is written to a local file and its path is returned. + + Args: + config: HTTPS URL to a YAML or JSON file. + + Returns: + Either the original `config` if it's not a valid HTTPS URL, or the path to the + downloaded remote config. + + Raises: + ValueError: If the remote configuration is neither valid JSON or YAML. + RuntimeError: If some request-related exception occurs from the file download. + Exception: Catch-all for any other exception. + """ + # Check if the config is a valid HTTPS URL to a .yml or .yaml file + if not (isinstance(config, str) and config.startswith("https://")): + return config # Return the original value if it's not a valid URL + + filename = os.path.basename(urlparse(config).path) + temp_dir = tempfile.mkdtemp() + + try: + response = requests.get(config, timeout=30) + response.raise_for_status() # Check for HTTP errors + + content = response.content + try: + # Try parsing as JSON first to catch cases where JSON content is mistakenly + # considered YAML. + json.loads(content) + + # Log a warning but do not raise an error; JSON is technically valid YAML. + # This can happen when you forget to point to a raw GitHub link. + LOG.warning( + f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended." 
+ ) + except json.JSONDecodeError: + # If it's not valid JSON, verify it's valid YAML + try: + yaml.safe_load(content) + except yaml.YAMLError as err: + raise ValueError( + f"Failed to parse the content at {config} as YAML: {err}" + ) from err + + # Write the content to a file if it's valid YAML (or JSON treated as YAML) + output_path = Path(temp_dir) / filename + with open(output_path, "wb") as file: + file.write(content) + LOG.info( + f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n" + ) + return output_path + + except requests.RequestException as err: + # This catches all requests-related exceptions including HTTPError + raise RuntimeError(f"Failed to download {config}: {err}") from err + except Exception as err: + # Catch-all for any other exceptions + raise err + + +def choose_config(path: Path) -> str: + """ + Helper method for choosing an `axolotl` config YAML file (considering only files + ending with `.yml` or `.yaml`). If more than one config file exists in the passed + `path`, the user is prompted to choose one. + + Args: + path: Directory in which config file(s) are stored. + + Returns: + Path to either (1) the sole YAML file, or (2) if more than one YAML file exists, + the user-selected YAML file. + + Raises: + ValueError: If no YAML files are found in the given `path`. + """ + yaml_files = list(path.glob("*.yml")) + list(path.glob("*.yaml")) + + if not yaml_files: + raise ValueError( + "No YAML config files found in the specified directory. Are you using a .yml extension?" + ) + + if len(yaml_files) == 1: + print(f"Using default YAML file '{yaml_files[0]}'") + return str(yaml_files[0]) + + print("Choose a YAML file:") + for idx, file in enumerate(yaml_files): + print(f"{idx + 1}. {file}") + + chosen_file = None + while chosen_file is None: + try: + choice = int(input("Enter the number of your choice: ")) + if 1 <= choice <= len(yaml_files): + chosen_file = str(yaml_files[choice - 1]) + else: + print("Invalid choice. Please choose a number from the list.") + except ValueError: + print("Invalid input. Please enter a number.") + + return chosen_file + + +def prepare_plugins(cfg: DictDefault): + """ + Registers the plugins for the given configuration. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + """ + if cfg.get("plugins"): + plugin_manager = PluginManager.get_instance() + for plugin_name in cfg["plugins"]: + plugin_manager.register(plugin_name) + + +def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault: + """ + Loads the `axolotl` configuration stored at `config`, validates it, and performs + various setup. + + Args: + config: Path (local or remote) to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + + Returns: + `DictDefault` mapping configuration keys to values.
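Worth noting from the `load_cfg` docstring above: local file paths, directories of YAML files, and raw HTTPS URLs are all accepted. A minimal usage sketch (the URL and the override key below are illustrative placeholders, not guaranteed paths):

```python
from axolotl.cli.config import load_cfg

# Remote configs are downloaded to a temp dir and validated as YAML first.
cfg = load_cfg(
    "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml",
    micro_batch_size=2,  # CLI-style override merged into the YAML values
)
print(cfg.axolotl_config_path)  # path to the downloaded temp copy
```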
+ """ + config = check_remote_config(config) + if Path(config).is_dir(): + config = choose_config(Path(config)) + + # Load the config from the yaml file + with open(config, encoding="utf-8") as file: + cfg: DictDefault = DictDefault(yaml.safe_load(file)) + + # If there are any options passed in the cli, if it is something that seems valid + # from the yaml, then overwrite the value + cfg_keys = cfg.keys() + for k, _ in kwargs.items(): + # if not strict, allow writing to cfg even if it's not in the yml already + if k in cfg_keys or not cfg.strict: + # handle booleans + if isinstance(cfg[k], bool): + cfg[k] = bool(kwargs[k]) + else: + cfg[k] = kwargs[k] + + cfg.axolotl_config_path = config + + try: + device_props = torch.cuda.get_device_properties("cuda") + gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) + except: # pylint: disable=bare-except # noqa: E722 + gpu_version = None + + prepare_plugins(cfg) + + cfg = validate_config( + cfg, + capabilities={ + "bf16": is_torch_bf16_gpu_available(), + "n_gpu": int(os.environ.get("WORLD_SIZE", 1)), + "compute_capability": gpu_version, + }, + env_capabilities={ + "torch_version": str(torch.__version__).split("+", maxsplit=1)[0] + }, + ) + + prepare_optim_env(cfg) + prepare_opinionated_env(cfg) + normalize_config(cfg) + normalize_cfg_datasets(cfg) + setup_wandb_env_vars(cfg) + setup_mlflow_env_vars(cfg) + setup_comet_env_vars(cfg) + + return cfg diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 8e99d6f4b..c89715719 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -1,6 +1,5 @@ -""" -CLI to run training on a model -""" +"""CLI to run evaluation on a model.""" + import logging from pathlib import Path from typing import Union @@ -9,35 +8,48 @@ import fire from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - load_cfg, - load_datasets, - load_rl_datasets, - print_axolotl_text_art, -) -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.checks import check_accelerate_default_config, check_user_token +from axolotl.cli.config import load_cfg +from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.evaluate import evaluate +from axolotl.utils.dict import DictDefault -LOG = logging.getLogger("axolotl.cli.evaluate") +LOG = logging.getLogger(__name__) -def do_evaluate(cfg, cli_args) -> None: +def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: + """ + Evaluates a `transformers` model by first loading the dataset(s) specified in the + `axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes + evaluation metrics on the given dataset(s) and writes them to disk. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: CLI arguments. 
+ """ # pylint: disable=duplicate-code print_axolotl_text_art() check_accelerate_default_config() check_user_token() - if cfg.rl: # and cfg.rl != "orpo": - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + if cfg.rl: + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + evaluate(cfg=cfg, dataset_meta=dataset_meta) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, CLI args, and calls `do_evaluate`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ # pylint: disable=duplicate-code parsed_cfg = load_cfg(config, **kwargs) parser = HfArgumentParser(TrainerCliArgs) diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index a5f1a8ad8..e11a39bd6 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -1,32 +1,267 @@ -""" -CLI to run inference on a trained model -""" +"""CLI to run inference on a trained model.""" + +import importlib +import logging +import sys from pathlib import Path +from threading import Thread from typing import Union import fire +import torch import transformers from dotenv import load_dotenv +from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer -from axolotl.cli import ( - do_inference, - do_inference_gradio, - load_cfg, - print_axolotl_text_art, +from axolotl.cli.args import InferenceCliArgs +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.config import load_cfg +from axolotl.cli.utils import load_model_and_tokenizer +from axolotl.utils.chat_templates import ( + get_chat_template, + get_chat_template_from_config, ) -from axolotl.common.cli import TrainerCliArgs +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) -def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs): +def get_multi_line_input() -> str: + """ + Gets multi-line input from terminal. + + Returns: + Possibly multi-line, possibly empty stdin input as a string. + """ + print("Give me an instruction (Ctrl + D to submit): ") + + instruction = "" + for line in sys.stdin: + instruction += line # pylint: disable=consider-using-join + + return instruction + + +def do_inference( + *, + cfg: DictDefault, + cli_args: InferenceCliArgs, +): + """ + Runs inference on the command line in a loop. User input is accepted, a chat template + is (optionally) applied, and the model specified in the `axolotl` config is used to + generate completions according to a default generation config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Inference-specific CLI arguments. 
+ """ + model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True) + prompter = cli_args.prompter + + prompter_module = None + chat_template_str = None + if prompter: + prompter_module = getattr( + importlib.import_module("axolotl.prompters"), prompter + ) + elif cfg.chat_template: + chat_template_str = get_chat_template(cfg.chat_template) + elif cfg.datasets[0].type == "chat_template": + chat_template_str = get_chat_template_from_config( + cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer + ) + + model = model.to(cfg.device, dtype=cfg.torch_dtype) + + while True: + print("=" * 80) + # support for multiline inputs + instruction = get_multi_line_input() + if not instruction: + return + + if prompter_module: + prompt: str = next( + prompter_module().build_prompt(instruction=instruction.strip("\n")) + ) + else: + prompt = instruction.strip() + + if chat_template_str: + batch = tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": prompt, + } + ], + return_tensors="pt", + add_special_tokens=True, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=True, + return_dict=True, + ) + else: + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + + print("=" * 40) + model.eval() + with torch.no_grad(): + generation_config = GenerationConfig( + repetition_penalty=1.1, + max_new_tokens=1024, + temperature=0.9, + top_p=0.95, + top_k=40, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=True, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + streamer = TextStreamer(tokenizer) + generated = model.generate( + inputs=batch["input_ids"].to(cfg.device), + generation_config=generation_config, + streamer=streamer, + ) + print("=" * 40) + print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) + + +def do_inference_gradio( + *, + cfg: DictDefault, + cli_args: InferenceCliArgs, +): + """ + Runs inference in a Gradio interface. User input is accepted, a chat template is + (optionally) applied, and the model specified in the `axolotl` config is used to + generate completions according to a default generation config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Inference-specific CLI arguments. 
+ """ + import gradio as gr + + model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True) + prompter = cli_args.prompter + + prompter_module = None + chat_template_str = None + if prompter: + prompter_module = getattr( + importlib.import_module("axolotl.prompters"), prompter + ) + elif cfg.chat_template: + chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer) + + model = model.to(cfg.device, dtype=cfg.torch_dtype) + + def generate(instruction): + if not instruction: + return + if prompter_module: + # pylint: disable=stop-iteration-return + prompt: str = next( + prompter_module().build_prompt(instruction=instruction.strip("\n")) + ) + else: + prompt = instruction.strip() + + if chat_template_str: + batch = tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": prompt, + } + ], + return_tensors="pt", + add_special_tokens=True, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=True, + return_dict=True, + ) + else: + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + + model.eval() + with torch.no_grad(): + generation_config = GenerationConfig( + repetition_penalty=1.1, + max_new_tokens=cfg.get("gradio_max_new_tokens", 1024), + temperature=cfg.get("gradio_temperature", 0.9), + top_p=0.95, + top_k=40, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=True, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + streamer = TextIteratorStreamer(tokenizer) + generation_kwargs = { + "inputs": batch["input_ids"].to(cfg.device), + "attention_mask": batch["attention_mask"].to(cfg.device), + "generation_config": generation_config, + "streamer": streamer, + } + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + all_text = "" + + for new_text in streamer: + all_text += new_text + yield all_text + + demo = gr.Interface( + fn=generate, + inputs="textbox", + outputs="text", + title=cfg.get("gradio_title", "Axolotl Gradio Interface"), + ) + + demo.queue().launch( + show_api=False, + share=cfg.get("gradio_share", True), + server_name=cfg.get("gradio_server_name", "127.0.0.1"), + server_port=cfg.get("gradio_server_port", None), + ) + + +def do_cli( + config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs +) -> None: + """ + Parses axolotl config, CLI args, and calls `do_inference` or `do_inference_gradio`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. 
+ """ # pylint: disable=duplicate-code print_axolotl_text_art() parsed_cfg = load_cfg(config, inference=True, **kwargs) parsed_cfg.sample_packing = False - parser = transformers.HfArgumentParser((TrainerCliArgs)) + parser = transformers.HfArgumentParser(InferenceCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) - parsed_cli_args.inference = True if gradio: do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 14803e43b..43e2de3db 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -1,18 +1,20 @@ -"""CLI definition for various axolotl commands.""" +"""Click CLI definitions for various axolotl commands.""" # pylint: disable=redefined-outer-name + import subprocess # nosec B404 from typing import Optional import click import axolotl +from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.cli.utils import ( add_options_from_config, add_options_from_dataclass, build_command, fetch_from_github, + filter_none_kwargs, ) -from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig @@ -27,10 +29,16 @@ def cli(): @click.argument("config", type=click.Path(exists=True, path_type=str)) @add_options_from_dataclass(PreprocessCliArgs) @add_options_from_config(AxolotlInputConfig) -def preprocess(config: str, **kwargs): - """Preprocess datasets before training.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} +@filter_none_kwargs +def preprocess(config: str, **kwargs) -> None: + """ + Preprocess datasets before training. + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ from axolotl.cli.preprocess import do_cli do_cli(config=config, **kwargs) @@ -45,10 +53,17 @@ def preprocess(config: str, **kwargs): ) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) -def train(config: str, accelerate: bool, **kwargs): - """Train or fine-tune a model.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} +@filter_none_kwargs +def train(config: str, accelerate: bool, **kwargs) -> None: + """ + Train or fine-tune a model. + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ # Enable expandable segments for cuda allocation to improve VRAM usage set_pytorch_cuda_alloc_conf() @@ -73,10 +88,17 @@ def train(config: str, accelerate: bool, **kwargs): ) @add_options_from_dataclass(EvaluateCliArgs) @add_options_from_config(AxolotlInputConfig) -def evaluate(config: str, accelerate: bool, **kwargs): - """Evaluate a model.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} +@filter_none_kwargs +def evaluate(config: str, accelerate: bool, **kwargs) -> None: + """ + Evaluate a model. + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. 
+ """ if accelerate: base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"] if config: @@ -96,81 +118,33 @@ def evaluate(config: str, accelerate: bool, **kwargs): default=False, help="Use accelerate launch for multi-GPU inference", ) -@click.option( - "--lora-model-dir", - type=click.Path(exists=True, path_type=str), - help="Directory containing LoRA model", -) -@click.option( - "--base-model", - type=click.Path(exists=True, path_type=str), - help="Path to base model for non-LoRA models", -) @click.option("--gradio", is_flag=True, help="Launch Gradio interface") -@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode") @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) -def inference( - config: str, - accelerate: bool, - lora_model_dir: Optional[str] = None, - base_model: Optional[str] = None, - **kwargs, -): - """Run inference with a trained model.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} - del kwargs["inference"] # interferes with inference.do_cli - - if lora_model_dir: - kwargs["lora_model_dir"] = lora_model_dir - if base_model: - kwargs["base_model"] = base_model +@filter_none_kwargs +def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None: + """ + Run inference with a trained model. + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + gradio: Whether to use Gradio browser interface or command line for inference. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ if accelerate: base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"] if config: base_cmd.append(config) + if gradio: + base_cmd.append("--gradio") cmd = build_command(base_cmd, kwargs) subprocess.run(cmd, check=True) # nosec B603 else: from axolotl.cli.inference import do_cli - do_cli(config=config, **kwargs) - - -@cli.command() -@click.argument("config", type=click.Path(exists=True, path_type=str)) -@click.option( - "--accelerate/--no-accelerate", - default=False, - help="Use accelerate launch for multi-GPU operations", -) -@click.option( - "--model-dir", - type=click.Path(exists=True, path_type=str), - help="Directory containing model weights to shard", -) -@click.option( - "--save-dir", - type=click.Path(path_type=str), - help="Directory to save sharded weights", -) -@add_options_from_dataclass(TrainerCliArgs) -@add_options_from_config(AxolotlInputConfig) -def shard(config: str, accelerate: bool, **kwargs): - """Shard model weights.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} - - if accelerate: - base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"] - if config: - base_cmd.append(config) - cmd = build_command(base_cmd, kwargs) - subprocess.run(cmd, check=True) # nosec B603 - else: - from axolotl.cli.shard import do_cli - - do_cli(config=config, **kwargs) + do_cli(config=config, gradio=gradio, **kwargs) @cli.command() @@ -180,20 +154,19 @@ def shard(config: str, accelerate: bool, **kwargs): default=True, help="Use accelerate launch for weight merging", ) -@click.option( - "--model-dir", - type=click.Path(exists=True, path_type=str), - help="Directory containing sharded weights", -) -@click.option( - "--save-path", type=click.Path(path_type=str), help="Path to save merged weights" -) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) -def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs): - """Merge sharded 
FSDP model weights.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} +@filter_none_kwargs +def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None: + """ + Merge sharded FSDP model weights. + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ if accelerate: base_cmd = [ "accelerate", @@ -213,28 +186,19 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs): @cli.command() @click.argument("config", type=click.Path(exists=True, path_type=str)) -@click.option( - "--lora-model-dir", - type=click.Path(exists=True, path_type=str), - help="Directory containing the LoRA model to merge", -) -@click.option( - "--output-dir", - type=click.Path(path_type=str), - help="Directory to save the merged model", -) -def merge_lora( - config: str, - lora_model_dir: Optional[str] = None, - output_dir: Optional[str] = None, -): - """Merge a trained LoRA into a base model""" - kwargs = {} - if lora_model_dir: - kwargs["lora_model_dir"] = lora_model_dir - if output_dir: - kwargs["output_dir"] = output_dir +@add_options_from_dataclass(TrainerCliArgs) +@add_options_from_config(AxolotlInputConfig) +@filter_none_kwargs +def merge_lora(config: str, **kwargs) -> None: + """ + Merge trained LoRA adapters into a base model. + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ from axolotl.cli.merge_lora import do_cli do_cli(config=config, **kwargs) @@ -243,13 +207,17 @@ def merge_lora( @cli.command() @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"])) @click.option("--dest", help="Destination directory") -def fetch(directory: str, dest: Optional[str]): +def fetch(directory: str, dest: Optional[str]) -> None: """ Fetch example configs or other resources. Available directories: - examples: Example configuration files - deepspeed_configs: DeepSpeed configuration files + + Args: + directory: One of `examples`, `deepspeed_configs`. + dest: Optional destination directory. """ fetch_from_github(f"{directory}/", dest) diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 8c321bc48..595eb3eab 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -1,6 +1,6 @@ -""" -CLI to run merge a trained LoRA into a base model -""" +"""CLI to merge a trained LoRA into a base model.""" + +import logging from pathlib import Path from typing import Union @@ -8,14 +8,58 @@ import fire import transformers from dotenv import load_dotenv -from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.config import load_cfg +from axolotl.cli.utils import load_model_and_tokenizer +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code +def do_merge_lora(*, cfg: DictDefault) -> None: + """ + Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config + along with the LoRA adapters to combine them into a single base model.
+ + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + """ + print_axolotl_text_art() - parser = transformers.HfArgumentParser((TrainerCliArgs)) + + model, tokenizer = load_model_and_tokenizer(cfg=cfg) + safe_serialization = cfg.save_safetensors is True + + LOG.info("Running merge of LoRA with base model...") + model = model.merge_and_unload(progressbar=True) + model.to(dtype=cfg.torch_dtype) + model.generation_config.do_sample = True + + if cfg.local_rank == 0: + LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...") + model.save_pretrained( + str(Path(cfg.output_dir) / "merged"), + safe_serialization=safe_serialization, + progressbar=True, + ) + tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) + + +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various + config values will be overwritten to allow the LoRA merge logic to work as expected + (`load_in_8bit=False`, `load_in_4bit=False`, `flash_attention=False`, etc.). + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + + Raises: + ValueError: If target directory for LoRA merged model does not exist. + """ + # pylint: disable=duplicate-code + parser = transformers.HfArgumentParser(TrainerCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) @@ -46,7 +90,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): parsed_cfg.fsdp = None parsed_cfg.fsdp_config = None - do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) + do_merge_lora(cfg=parsed_cfg) if __name__ == "__main__": diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index 6be9af1f7..d4b36d92c 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -1,6 +1,5 @@ -""" -This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint -""" +"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint.""" + import json import logging import os @@ -25,16 +24,15 @@ from huggingface_hub import split_torch_state_dict_into_shards from safetensors.torch import save_file as safe_save_file from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner -from axolotl.cli import load_cfg, print_axolotl_text_art -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.config import load_cfg -LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights") +LOG = logging.getLogger(__name__) class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): - """ - A custom planner to cast tensors to bfloat16 on the fly during loading. - """ + """A custom planner to cast tensors to bfloat16 on the fly during loading.""" def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument tensor.copy_(tensor.to(torch.bfloat16)) @@ -45,11 +43,19 @@ def _distributed_checkpoint_to_merged_weights( save_path: str, safe_serialization: bool = False, max_shard_size: str = "5GB", -): +) -> Path: """ - Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save` + Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`.
Will + save under `save_path` as either `model.safetensors` or `pytorch_model.bin`. - Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`. + Args: + checkpoint_dir: Directory where distributed checkpoint is saved. + save_path: Path to save model to. + safe_serialization: Whether to save in safetensors format. + max_shard_size: Max size of model shards to save. + + Returns: + Path where model is saved. """ state_dict: Dict = {} @@ -79,6 +85,7 @@ def _distributed_checkpoint_to_merged_weights( state_dict_split = split_torch_state_dict_into_shards( state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size ) + # Save index if sharded index = None if state_dict_split.is_sharded: @@ -135,6 +142,9 @@ def merge_fsdp_weights( Whether to save the merged weights with safetensors (recommended). remove_checkpoint_dir (`bool`, *optional*, defaults to `False`): Whether to remove the checkpoint directory after merging. + + Raises: + ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist. """ checkpoint_dir_ = Path(checkpoint_dir) from accelerate.state import PartialState @@ -178,18 +188,21 @@ def merge_fsdp_weights( def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + """ + Parses `axolotl` config, CLI args, and calls `merge_fsdp_weights`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ # pylint: disable=duplicate-code print_axolotl_text_art() - parser = transformers.HfArgumentParser((TrainerCliArgs)) + parser = transformers.HfArgumentParser(TrainerCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) parsed_cli_args.merge_lora = True - - parsed_cfg = load_cfg( - config, - **kwargs, - ) + parsed_cfg = load_cfg(config, **kwargs) fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0" merge_fsdp_weights( diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index a1592aa78..760fe76fa 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -1,6 +1,5 @@ -""" -CLI to run training on a model -""" +"""CLI to run preprocessing of a dataset.""" + import logging import warnings from pathlib import Path @@ -13,34 +12,31 @@ from colorama import Fore from dotenv import load_dotenv from transformers import AutoModelForCausalLM -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - load_cfg, - load_datasets, - load_rl_datasets, - print_axolotl_text_art, -) -from axolotl.common.cli import PreprocessCliArgs +from axolotl.cli.args import PreprocessCliArgs +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.checks import check_accelerate_default_config, check_user_token +from axolotl.cli.config import load_cfg from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.common.datasets import load_datasets, load_preference_datasets +from axolotl.utils.dict import DictDefault from axolotl.utils.trainer import disable_datasets_caching -LOG = logging.getLogger("axolotl.cli.preprocess") +LOG = logging.getLogger(__name__) -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code +def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: + """ + Preprocesses dataset specified in axolotl config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Preprocessing-specific CLI arguments. 
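A sketch of driving the new `do_preprocess` helper directly, mirroring what its `do_cli` wrapper sets up (the config path is a placeholder; `dataset_prepared_path` falls back to a default with a warning when unset):

```python
from axolotl.cli.args import PreprocessCliArgs
from axolotl.cli.config import load_cfg
from axolotl.cli.preprocess import do_preprocess

cfg = load_cfg("config.yml")  # placeholder path
cfg.is_preprocess = True

# download=True also prefetches the base model weights after tokenization.
do_preprocess(cfg, PreprocessCliArgs(download=True))
```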
+ """ print_axolotl_text_art() - parsed_cfg = load_cfg(config, **kwargs) - parsed_cfg.is_preprocess = True check_accelerate_default_config() check_user_token() - parser = transformers.HfArgumentParser((PreprocessCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - if not parsed_cfg.dataset_prepared_path: + if not cfg.dataset_prepared_path: msg = ( Fore.RED + "preprocess CLI called without dataset_prepared_path set, " @@ -48,16 +44,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + Fore.RESET ) LOG.warning(msg) - parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH + cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH with disable_datasets_caching(): - if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": - load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + if cfg.rl: + load_preference_datasets(cfg=cfg, cli_args=cli_args) else: - load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + load_datasets(cfg=cfg, cli_args=cli_args) - if parsed_cli_args.download: - model_name = parsed_cfg.base_model + if cli_args.download: + model_name = cfg.base_model with warnings.catch_warnings(): # there are a bunch of useless UserWarnings about # "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model" @@ -74,11 +70,30 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): LOG.info( Fore.GREEN - + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`" + + f"Success! Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`" + Fore.RESET ) +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, CLI args, and calls `do_preprocess`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. 
+ """ + # pylint: disable=duplicate-code + parsed_cfg = load_cfg(config, **kwargs) + parsed_cfg.is_preprocess = True + parser = transformers.HfArgumentParser(PreprocessCliArgs) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + + do_preprocess(parsed_cfg, parsed_cli_args) + + if __name__ == "__main__": load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/shard.py b/src/axolotl/cli/shard.py deleted file mode 100644 index 196c0e99a..000000000 --- a/src/axolotl/cli/shard.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -CLI to shard a trained model into 10GiB chunks -""" -import logging -from pathlib import Path -from typing import Union - -import fire -import transformers -from dotenv import load_dotenv - -from axolotl.cli import load_cfg, print_axolotl_text_art -from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer -from axolotl.utils.dict import DictDefault - -LOG = logging.getLogger("axolotl.scripts") - - -def shard( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - safe_serialization = cfg.save_safetensors is True - LOG.debug("Re-saving model w/ sharding") - model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) - - -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code - print_axolotl_text_art() - parsed_cfg = load_cfg(config, **kwargs) - parser = transformers.HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - parsed_cli_args.shard = True - - shard(cfg=parsed_cfg, cli_args=parsed_cli_args) - - -if __name__ == "__main__": - load_dotenv() - fire.Fire(do_cli) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 2a40e854e..9e3ae1cc3 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -1,6 +1,5 @@ -""" -CLI to run training on a model -""" +"""CLI to run training on a model.""" + import logging from pathlib import Path from typing import Union @@ -9,42 +8,38 @@ import fire from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - load_cfg, - load_datasets, - load_rl_datasets, - print_axolotl_text_art, -) -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.checks import check_accelerate_default_config, check_user_token +from axolotl.cli.config import load_cfg +from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.integrations.base import PluginManager from axolotl.train import train +from axolotl.utils.dict import DictDefault -LOG = logging.getLogger("axolotl.cli.train") +LOG = logging.getLogger(__name__) -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code - parsed_cfg = load_cfg(config, **kwargs) - parser = HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - return do_train(parsed_cfg, parsed_cli_args) +def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: + """ + Trains a `transformers` model by first loading the dataset(s) specified in the + `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin + manager's `post_train_unload` once training completes. 
- -def do_train(cfg, cli_args) -> None: + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Training-specific CLI arguments. + """ print_axolotl_text_art() check_accelerate_default_config() check_user_token() - if cfg.rl: # and cfg.rl != "orpo": - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + if cfg.rl: + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta) plugin_manager = PluginManager.get_instance() del model @@ -53,6 +48,24 @@ def do_train(cfg, cli_args) -> None: plugin_manager.post_train_unload(cfg) +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, CLI args, and calls `do_train`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ + # pylint: disable=duplicate-code + parsed_cfg = load_cfg(config, **kwargs) + parser = HfArgumentParser(TrainerCliArgs) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + + do_train(parsed_cfg, parsed_cli_args) + + if __name__ == "__main__": load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index 85d241b5d..addfa0ab9 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -1,32 +1,85 @@ -"""Utility methods for axoltl CLI.""" +"""Utility methods for axolotl CLI.""" + import concurrent.futures import dataclasses import hashlib import json import logging +import typing +from functools import wraps from pathlib import Path from types import NoneType -from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin +from typing import Any, Callable, Type, Union, get_args, get_origin import click import requests from pydantic import BaseModel +from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast -LOG = logging.getLogger("axolotl.cli.utils") +from axolotl.logging_config import configure_logging +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_tokenizer + +configure_logging() +LOG = logging.getLogger(__name__) -def add_options_from_dataclass(config_class: Type[Any]): - """Create Click options from the fields of a dataclass.""" +def strip_optional_type(field_type: type | typing._SpecialForm | None): + """ + Extracts the non-`None` type from an `Optional` / `Union` type. - def decorator(function): + Args: + field_type: Type of field for Axolotl CLI command. + + Returns: + If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise + returns the input type unchanged. + """ + if get_origin(field_type) is Union and type(None) in get_args(field_type): + field_type = next( + t for t in get_args(field_type) if not isinstance(t, NoneType) + ) + + return field_type + + +def filter_none_kwargs(func: Callable) -> Callable: + """ + Wraps function to remove `None`-valued `kwargs`. + + Args: + func: Function to wrap. + + Returns: + Wrapped function. 
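These two helpers compose: `strip_optional_type` normalizes `Optional[...]` annotations before the bool check, and `filter_none_kwargs` keeps unset Click options from clobbering config values. A small demonstration sketch:

```python
from typing import Optional

from axolotl.cli.utils import filter_none_kwargs, strip_optional_type

assert strip_optional_type(Optional[str]) is str
assert strip_optional_type(int) is int  # non-Optional types pass through

@filter_none_kwargs
def echo(**kwargs):
    return kwargs

# None-valued options (i.e., flags the user never passed) are dropped.
assert echo(learning_rate=2e-5, seed=None) == {"learning_rate": 2e-5}
```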
+ """ + + @wraps(func) + def wrapper(*args, **kwargs) -> Callable: + """Filters out `None`-valued `kwargs`.""" + filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + + return func(*args, **filtered_kwargs) + + return wrapper + + +def add_options_from_dataclass(config_class: Type[Any]) -> Callable: + """ + Create Click options from the fields of a dataclass. + + Args: + config_class: Dataclass with fields to parse from the CLI. + + Returns: + Function decorator for Axolotl CLI command. + """ + + def decorator(function: Callable) -> Callable: # Process dataclass fields in reverse order for correct option ordering for field in reversed(dataclasses.fields(config_class)): - field_type = field.type + field_type = strip_optional_type(field.type) - if get_origin(field_type) is Union and type(None) in get_args(field_type): - field_type = next( - t for t in get_args(field_type) if not isinstance(t, NoneType) - ) if field_type == bool: field_name = field.name.replace("_", "-") option_name = f"--{field_name}/--no-{field_name}" @@ -43,18 +96,29 @@ def add_options_from_dataclass(config_class: Type[Any]): default=field.default, help=field.metadata.get("description"), )(function) + return function return decorator -def add_options_from_config(config_class: Type[BaseModel]): - """Create Click options from the fields of a Pydantic model.""" +def add_options_from_config(config_class: Type[BaseModel]) -> Callable: + """ + Create Click options from the fields of a Pydantic model. - def decorator(function): + Args: + config_class: PyDantic model with fields to parse from the CLI + + Returns: + Function decorator for Axolotl CLI command. + """ + + def decorator(function: Callable) -> Callable: # Process model fields in reverse order for correct option ordering for name, field in reversed(config_class.model_fields.items()): - if field.annotation == bool: + field_type = strip_optional_type(field.annotation) + + if field_type == bool: field_name = name.replace("_", "-") option_name = f"--{field_name}/--no-{field_name}" function = click.option( @@ -65,13 +129,23 @@ def add_options_from_config(config_class: Type[BaseModel]): function = click.option( option_name, default=None, help=field.description )(function) + return function return decorator -def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]: - """Build command list from base command and options.""" +def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]: + """ + Build command list from base command and options. + + Args: + base_cmd: Command without options. + options: Options to parse and append to base command. + + Returns: + List of strings giving shell command. + """ cmd = base_cmd.copy() for key, value in options.items(): @@ -91,18 +165,18 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]: def download_file( file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str -) -> Tuple[str, str]: +) -> tuple[str, str]: """ Download a single file and return its processing status. Args: - file_info: Tuple of (file_path, remote_sha) - raw_base_url: Base URL for raw GitHub content - dest_path: Local destination directory - dir_prefix: Directory prefix to filter files + file_info: Tuple of (file_path, remote_sha). + raw_base_url: Base URL for raw GitHub content. + dest_path: Local destination directory. + dir_prefix: Directory prefix to filter files. 
Returns: - Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged' + Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'. """ file_path, remote_sha = file_info raw_url = f"{raw_base_url}/{file_path}" @@ -144,16 +218,17 @@ def download_file( def fetch_from_github( - dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5 + dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5 ) -> None: """ Sync files from a specific directory in the GitHub repository. Only downloads files that don't exist locally or have changed. Args: - dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/') - dest_dir: Local destination directory - max_workers: Maximum number of concurrent downloads + dir_prefix: Directory prefix to filter files (e.g., 'examples/', + 'deepspeed_configs/'). + dest_dir: Local destination directory. + max_workers: Maximum number of concurrent downloads. """ api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1" raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main" @@ -178,7 +253,7 @@ def fetch_from_github( dest_path = Path(dest_dir) if dest_dir else default_dest # Keep track of processed files for summary - files_processed: Dict[str, List[str]] = { + files_processed: dict[str, list[str]] = { "new": [], "updated": [], "unchanged": [], @@ -215,3 +290,28 @@ def fetch_from_github( LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}") if files_processed["error"]: LOG.info(f"Failed files: {len(files_processed['error'])}") + + +def load_model_and_tokenizer( + *, + cfg: DictDefault, + inference: bool = False, +) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]: + """ + Helper function for loading a model and tokenizer specified in the given `axolotl` + config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + inference: Boolean denoting inference mode. + + Returns: + `transformers` model and tokenizer. + """ + LOG.info(f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}") + tokenizer = load_tokenizer(cfg) + + LOG.info("loading model...") + model, _ = load_model(cfg, tokenizer, inference=inference) + + return model, tokenizer diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py deleted file mode 100644 index 02ad9201b..000000000 --- a/src/axolotl/common/cli.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -shared module for cli specific things -""" - -import logging -from dataclasses import dataclass, field -from typing import Optional - -import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 -from axolotl.logging_config import configure_logging -from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_tokenizer - -configure_logging() -LOG = logging.getLogger("axolotl.common.cli") - - -@dataclass -class PreprocessCliArgs: - """ - dataclass representing arguments for preprocessing only - """ - - debug: bool = field(default=False) - debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=1) - prompter: Optional[str] = field(default=None) - download: Optional[bool] = field(default=True) - - -@dataclass -class TrainerCliArgs: - """ - dataclass representing the various non-training arguments - """ - - debug: bool = field(default=False) - debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=0) - inference: bool = field(default=False) - merge_lora: bool = field(default=False) - prompter: Optional[str] = field(default=None) - shard: bool = field(default=False) - - -@dataclass -class EvaluateCliArgs: - """ - dataclass representing the various evaluation arguments - """ - - debug: bool = field(default=False) - debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=0) - - -def load_model_and_tokenizer( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") - tokenizer = load_tokenizer(cfg) - - LOG.info("loading model and (optionally) peft_config...") - inference = getattr(cli_args, "inference", False) - model, _ = load_model(cfg, tokenizer, inference=inference) - - return model, tokenizer diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py new file mode 100644 index 000000000..d07add29b --- /dev/null +++ b/src/axolotl/common/datasets.py @@ -0,0 +1,140 @@ +"""Dataset loading utilities.""" + +import logging +import math +import random +from dataclasses import dataclass +from typing import Optional, Union + +from datasets import Dataset + +import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 +from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs +from axolotl.utils.data import prepare_dataset +from axolotl.utils.data.rl import load_prepare_dpo_datasets +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_processor, load_tokenizer +from axolotl.utils.tokenization import check_dataset_labels + +LOG = logging.getLogger(__name__) + + +@dataclass +class TrainDatasetMeta: + """Dataclass with fields for training and validation datasets and metadata.""" + + train_dataset: Dataset + eval_dataset: Optional[Dataset] = None + total_num_steps: Optional[int] = None + + +def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: + """ + Randomly sample `num_samples` samples from `dataset`. + + Args: + dataset: Dataset. 
+ num_samples: Number of samples to return. + + Returns: + Random sample (with replacement) of examples in `dataset`. + """ + return dataset.select( + [random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec + ) + + +def load_datasets( + *, + cfg: DictDefault, + cli_args: Union[PreprocessCliArgs, TrainerCliArgs], +) -> TrainDatasetMeta: + """ + Loads one or more training or evaluation datasets, calling + `axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Command-specific CLI arguments. + + Returns: + Dataclass with fields for training and evaluation datasets and the computed + `total_num_steps`. + """ + tokenizer = load_tokenizer(cfg) + processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None + + train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( + cfg, + tokenizer, + processor=processor, + ) + + if ( + cli_args.debug + or cfg.debug + or cli_args.debug_text_only + or int(cli_args.debug_num_examples) > 0 + ): + LOG.info("check_dataset_labels...") + + train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) + check_dataset_labels( + train_samples, + tokenizer, + num_examples=cli_args.debug_num_examples, + text_only=cli_args.debug_text_only, + ) + + LOG.info("printing prompters...") + for prompter in prompters: + LOG.info(prompter) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + total_num_steps=total_num_steps, + ) + + +def load_preference_datasets( + *, + cfg: DictDefault, + cli_args: Union[PreprocessCliArgs, TrainerCliArgs], +) -> TrainDatasetMeta: + """ + Loads one or more training or evaluation datasets for DPO training, calling + `axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug + information. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Command-specific CLI arguments. + + Returns: + Dataclass with fields for training and evaluation datasets and the computed + `total_num_steps`. 
+ """ + train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + + if cli_args.debug or cfg.debug: + LOG.info("check_dataset_labels...") + + tokenizer = load_tokenizer(cfg) + train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) + check_dataset_labels( + train_samples, + tokenizer, + num_examples=cli_args.debug_num_examples, + text_only=cli_args.debug_text_only, + rl_mode=True, + ) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + total_num_steps=total_num_steps, + ) diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index acf15e3fc..8d9ddc6ab 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -9,7 +9,6 @@ from typing import Dict, Optional import torch from accelerate.logging import get_logger -from axolotl.common.cli import TrainerCliArgs from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta from axolotl.utils import set_pytorch_cuda_alloc_conf @@ -62,16 +61,13 @@ def evaluate_dataset( return metrics -def evaluate( - *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta -) -> Dict[str, float]: +def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]: """ Evaluate a model on training and validation datasets Args: - cfg: Configuration dictionary - cli_args: Command line arguments - dataset_meta: Dataset metadata containing training and evaluation datasets + cfg: Dictionary mapping `axolotl` config keys to values. + dataset_meta: Dataset metadata containing training and evaluation datasets. Returns: Tuple containing: @@ -102,9 +98,7 @@ def evaluate( # Load model LOG.debug("loading model for evaluation...") - model, _ = load_model( - cfg, tokenizer, processor=processor, inference=cli_args.inference - ) + model, _ = load_model(cfg, tokenizer, processor=processor) # Set up trainer trainer = setup_trainer( diff --git a/src/axolotl/train.py b/src/axolotl/train.py index a74ecc2ec..b901c2a97 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -5,21 +5,19 @@ import os import signal import sys import weakref -from dataclasses import dataclass from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Tuple, Union import torch import transformers.modelcard from accelerate.logging import get_logger from accelerate.utils import save_fsdp_model -from datasets import Dataset from peft import PeftModel from pkg_resources import get_distribution # type: ignore from transformers import PreTrainedModel, PreTrainedTokenizer from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled -from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import TrainDatasetMeta from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) @@ -39,22 +37,11 @@ src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) configure_logging() -LOG = get_logger("axolotl.train") - - -@dataclass -class TrainDatasetMeta: - """ - dataclass to capture the dataset specific options for training - """ - - train_dataset: Dataset - eval_dataset: Optional[Dataset] = None - total_num_steps: Optional[int] = None +LOG = get_logger(__name__) def train( - *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta + *, cfg: DictDefault, dataset_meta: TrainDatasetMeta ) -> Tuple[Union[PeftModel, PreTrainedModel], 
PreTrainedTokenizer]: # Load tokenizer LOG.debug( @@ -93,9 +80,7 @@ def train( if cfg.adapter: msg += " and peft_config..." LOG.debug(msg) - model, peft_config = load_model( - cfg, tokenizer, processor=processor, inference=cli_args.inference - ) + model, peft_config = load_model(cfg, tokenizer, processor=processor) if model.generation_config is not None: model.generation_config.do_sample = True @@ -107,9 +92,7 @@ def train( model_ref = None # explicit setting to None else: # load the model again for model_ref/baseline - model_ref, _ = load_model( - cfg, tokenizer, inference=cli_args.inference, reference_model=True - ) + model_ref, _ = load_model(cfg, tokenizer, reference_model=True) safe_serialization = cfg.save_safetensors is True diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index aff047675..de373c06e 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -109,7 +109,9 @@ def prepare_dataset(cfg, tokenizer, processor=None): cfg.pretraining_dataset[0]["type"] or "pretrain", ) - iter_ds = load_dataset(path, streaming=True, split=split, name=name, data_files=data_files) + iter_ds = load_dataset( + path, streaming=True, split=split, name=name, data_files=data_files + ) if skip: LOG.info(f"Skipping {skip} samples from the dataset") iter_ds = iter_ds.skip(skip) diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index 78b090e19..d360e29d6 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -1,4 +1,5 @@ """Shared pytest fixtures for cli module.""" + import pytest from click.testing import CliRunner diff --git a/tests/cli/test_cli_fetch.py b/tests/cli/test_cli_fetch.py index 0df87b029..f06f06717 100644 --- a/tests/cli/test_cli_fetch.py +++ b/tests/cli/test_cli_fetch.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI fetch command.""" + from unittest.mock import patch from axolotl.cli.main import fetch diff --git a/tests/cli/test_cli_inference.py b/tests/cli/test_cli_inference.py index 7cb163d25..b8effa3d2 100644 --- a/tests/cli/test_cli_inference.py +++ b/tests/cli/test_cli_inference.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI inference command.""" + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_interface.py b/tests/cli/test_cli_interface.py index ed8335b76..8b5fec17f 100644 --- a/tests/cli/test_cli_interface.py +++ b/tests/cli/test_cli_interface.py @@ -1,4 +1,5 @@ """General pytest tests for axolotl.cli.main interface.""" + from axolotl.cli.main import build_command, cli diff --git a/tests/cli/test_cli_merge_lora.py b/tests/cli/test_cli_merge_lora.py index 165a64e98..aac016760 100644 --- a/tests/cli/test_cli_merge_lora.py +++ b/tests/cli/test_cli_merge_lora.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI merge_lora command.""" + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_merge_sharded_fsdp_weights.py b/tests/cli/test_cli_merge_sharded_fsdp_weights.py index cff0f3b77..18589a80d 100644 --- a/tests/cli/test_cli_merge_sharded_fsdp_weights.py +++ b/tests/cli/test_cli_merge_sharded_fsdp_weights.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" # pylint: disable=duplicate-code + from unittest.mock import patch from axolotl.cli.main import cli @@ -15,46 +16,3 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path): assert mock.called assert mock.call_args.kwargs["config"] == str(config_path) assert result.exit_code == 0 - - -def 
test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path): - """Test merge_sharded_fsdp_weights command with model_dir option""" - model_dir = tmp_path / "model" - model_dir.mkdir() - - with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "merge-sharded-fsdp-weights", - str(config_path), - "--no-accelerate", - "--model-dir", - str(model_dir), - ], - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert mock.call_args.kwargs["model_dir"] == str(model_dir) - assert result.exit_code == 0 - - -def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path): - """Test merge_sharded_fsdp_weights command with save_path option""" - with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "merge-sharded-fsdp-weights", - str(config_path), - "--no-accelerate", - "--save-path", - "/path/to/save", - ], - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert mock.call_args.kwargs["save_path"] == "/path/to/save" - assert result.exit_code == 0 diff --git a/tests/cli/test_cli_preprocess.py b/tests/cli/test_cli_preprocess.py index 4719461aa..e2dd3a6c3 100644 --- a/tests/cli/test_cli_preprocess.py +++ b/tests/cli/test_cli_preprocess.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI preprocess command.""" + import shutil from pathlib import Path from unittest.mock import patch diff --git a/tests/cli/test_cli_shard.py b/tests/cli/test_cli_shard.py deleted file mode 100644 index 505a2a737..000000000 --- a/tests/cli/test_cli_shard.py +++ /dev/null @@ -1,76 +0,0 @@ -"""pytest tests for axolotl CLI shard command.""" -# pylint: disable=duplicate-code -from unittest.mock import patch - -from axolotl.cli.main import cli - - -def test_shard_with_accelerate(cli_runner, config_path): - """Test shard command with accelerate""" - with patch("subprocess.run") as mock: - result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"]) - - assert mock.called - assert mock.call_args.args[0] == [ - "accelerate", - "launch", - "-m", - "axolotl.cli.shard", - str(config_path), - "--debug-num-examples", - "0", - ] - assert mock.call_args.kwargs == {"check": True} - assert result.exit_code == 0 - - -def test_shard_no_accelerate(cli_runner, config_path): - """Test shard command without accelerate""" - with patch("axolotl.cli.shard.do_cli") as mock: - result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"]) - - assert mock.called - assert result.exit_code == 0 - - -def test_shard_with_model_dir(cli_runner, config_path, tmp_path): - """Test shard command with model_dir option""" - model_dir = tmp_path / "model" - model_dir.mkdir() - - with patch("axolotl.cli.shard.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "shard", - str(config_path), - "--no-accelerate", - "--model-dir", - str(model_dir), - ], - catch_exceptions=False, - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert mock.call_args.kwargs["model_dir"] == str(model_dir) - assert result.exit_code == 0 - - -def test_shard_with_save_dir(cli_runner, config_path): - with patch("axolotl.cli.shard.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "shard", - str(config_path), - "--no-accelerate", - "--save-dir", - "/path/to/save", - ], - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert 
mock.call_args.kwargs["save_dir"] == "/path/to/save" - assert result.exit_code == 0 diff --git a/tests/cli/test_cli_version.py b/tests/cli/test_cli_version.py index 819780e94..533dd5c0e 100644 --- a/tests/cli/test_cli_version.py +++ b/tests/cli/test_cli_version.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI --version""" + from axolotl.cli.main import cli diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index b88e4ac72..ecb0025e4 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI utils.""" # pylint: disable=redefined-outer-name + import json from unittest.mock import Mock, patch diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py index 6562af176..291a4a4ec 100644 --- a/tests/e2e/integrations/test_cut_cross_entropy.py +++ b/tests/e2e/integrations/test_cut_cross_entropy.py @@ -4,8 +4,8 @@ Simple end-to-end test for Cut Cross Entropy integration import pytest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils import get_pytorch_version from axolotl.utils.config import normalize_config, prepare_plugins @@ -64,9 +64,9 @@ class TestCutCrossEntropyIntegration: major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): with pytest.raises(ImportError): - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) else: - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @pytest.mark.parametrize( @@ -92,7 +92,7 @@ class TestCutCrossEntropyIntegration: major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): with pytest.raises(ImportError): - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) else: - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py index 9154bf9b8..1efe889e4 100644 --- a/tests/e2e/integrations/test_liger.py +++ b/tests/e2e/integrations/test_liger.py @@ -4,8 +4,8 @@ Simple end-to-end test for Liger integration from e2e.utils import require_torch_2_4_1 -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.dict import DictDefault @@ -60,7 +60,7 @@ class LigerIntegrationTestCase: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @require_torch_2_4_1 @@ -105,5 +105,5 @@ class LigerIntegrationTestCase: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py index 
08b3bf0da..da27069ac 100644 --- a/tests/e2e/patched/test_4d_multipack_llama.py +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -65,7 +65,7 @@ class Test4dMultipackLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -109,5 +109,5 @@ class Test4dMultipackLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_cli_integrations.py b/tests/e2e/patched/test_cli_integrations.py index 6ca7c52ae..ce9396d5f 100644 --- a/tests/e2e/patched/test_cli_integrations.py +++ b/tests/e2e/patched/test_cli_integrations.py @@ -5,7 +5,7 @@ from pathlib import Path import yaml -from axolotl.cli import load_cfg +from axolotl.cli.config import load_cfg from axolotl.utils.dict import DictDefault diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py index 791d955b2..2bfd36d15 100644 --- a/tests/e2e/patched/test_fa_xentropy.py +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -8,8 +8,8 @@ import os import pytest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -80,7 +80,7 @@ class TestFAXentropyLlama: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py index 69516810f..62ee4f717 100644 --- a/tests/e2e/patched/test_falcon_samplepack.py +++ b/tests/e2e/patched/test_falcon_samplepack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestFalconPatched(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -107,5 +107,5 @@ class TestFalconPatched(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + 
train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py index 23a0adfc0..e7ab510c9 100644 --- a/tests/e2e/patched/test_fused_llama.py +++ b/tests/e2e/patched/test_fused_llama.py @@ -9,8 +9,8 @@ import unittest import pytest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -71,5 +71,5 @@ class TestFusedLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index d0fdd918a..8d0ba6c2a 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -8,8 +8,8 @@ import unittest import pytest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -69,7 +69,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -109,5 +109,5 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index 634e544d2..bc18e3d81 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -9,8 +9,8 @@ import unittest import pytest from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -74,7 +74,7 @@ class TestLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available") @@ -124,5 +124,5 @@ class TestLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) 
check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py index e93863e09..c7fd0ecbc 100644 --- a/tests/e2e/patched/test_mistral_samplepack.py +++ b/tests/e2e/patched/test_mistral_samplepack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -108,5 +108,5 @@ class TestMistral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index f87c34fd1..156dac7e8 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -64,7 +64,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -102,7 +102,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( "MixtralFlashAttention2" in model.model.layers[0].self_attn.__class__.__name__ diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py index 170c37fd6..78b01be64 100644 --- a/tests/e2e/patched/test_model_patches.py +++ b/tests/e2e/patched/test_model_patches.py @@ -6,7 +6,6 @@ import unittest import transformers -from axolotl.common.cli import TrainerCliArgs from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer @@ -49,9 +48,8 @@ class TestModelPatches(unittest.TestCase): } ) normalize_config(cfg) - cli_args = TrainerCliArgs() tokenizer = load_tokenizer(cfg) - model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + model, _ = load_model(cfg, tokenizer, inference=False) assert ( "MixtralFlashAttention2" @@ -87,9 +85,8 @@ class TestModelPatches(unittest.TestCase): } ) normalize_config(cfg) - cli_args = TrainerCliArgs() tokenizer = load_tokenizer(cfg) - load_model(cfg, tokenizer, inference=cli_args.inference) + load_model(cfg, tokenizer, inference=False) assert 
( "torch.jit" diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py index 852ac7bec..ce466460e 100644 --- a/tests/e2e/patched/test_phi_multipack.py +++ b/tests/e2e/patched/test_phi_multipack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestPhiMultipack(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -118,5 +118,5 @@ class TestPhiMultipack(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index 5639d2eae..f6a3e0109 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -9,8 +9,8 @@ import subprocess from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -71,7 +71,7 @@ class TestResumeLlama: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) resume_cfg = cfg | DictDefault( { @@ -81,7 +81,7 @@ class TestResumeLlama: normalize_config(resume_cfg) cli_args = TrainerCliArgs() - train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=resume_cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) tb_log_path_1 = most_recent_subdir(temp_dir + "/runs") diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 492bc1c23..da5eaffb6 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -6,8 +6,8 @@ import os import pytest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -75,7 +75,7 @@ class TestUnslothQLoRA: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( @@ -125,7 +125,7 @@ class TestUnslothQLoRA: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( @@ -180,7 +180,7 
@@ class TestUnslothQLoRA: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index f8109373a..2d0baceee 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -9,8 +9,8 @@ from pathlib import Path import pytest -from axolotl.cli import load_rl_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_preference_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -65,9 +65,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -110,9 +110,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -155,9 +155,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @pytest.mark.skip("kto_pair no longer supported in trl") @@ -200,9 +200,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -244,9 +244,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -291,9 +291,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @pytest.mark.skip(reason="Fix the implementation") @@ -355,7 +355,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) 
normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py index 222d620ae..4261ccc26 100644 --- a/tests/e2e/test_embeddings_lr.py +++ b/tests/e2e/test_embeddings_lr.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -60,7 +60,7 @@ class TestEmbeddingsLrScale(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( @@ -104,7 +104,7 @@ class TestEmbeddingsLrScale(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py index 117de6635..ddcb66275 100644 --- a/tests/e2e/test_falcon.py +++ b/tests/e2e/test_falcon.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -69,7 +69,7 @@ class TestFalcon(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -122,7 +122,7 @@ class TestFalcon(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -161,5 +161,5 @@ class TestFalcon(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index 4384bb61e..a94828490 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -7,8 +7,8 @@ import os from e2e.utils import check_model_output_exists -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -60,7 +60,7 @@ class TestLlama: cli_args = 
TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) def test_fix_untrained_tokens(self, temp_dir): @@ -103,7 +103,7 @@ class TestLlama: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) def test_batch_flattening(self, temp_dir): @@ -142,5 +142,5 @@ class TestLlama: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index d13b10659..68cd490be 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -62,5 +62,5 @@ class TestPretrainLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py index 250cf418c..91f101e44 100644 --- a/tests/e2e/test_llama_vision.py +++ b/tests/e2e/test_llama_vision.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -66,7 +66,7 @@ class TestLlamaVision(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -111,5 +111,5 @@ class TestLlamaVision(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index a7ead64a5..696c47aed 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -63,5 +63,5 @@ class TestLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - 
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py index a1fc30862..4b4db3058 100644 --- a/tests/e2e/test_mamba.py +++ b/tests/e2e/test_mamba.py @@ -8,8 +8,8 @@ import unittest import pytest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -63,5 +63,5 @@ class TestMamba(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py index 2e79fec8d..a304e9b4a 100644 --- a/tests/e2e/test_mistral.py +++ b/tests/e2e/test_mistral.py @@ -8,8 +8,8 @@ import unittest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -110,5 +110,5 @@ class TestMistral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index 6792d05a6..6e06626f6 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -9,8 +9,8 @@ import unittest import torch from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -73,7 +73,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 @@ -127,7 +127,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 @@ -184,7 +184,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, 
cli_args=cli_args) - model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 @@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 @@ -285,5 +285,5 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index f1bbaafd5..453872538 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -63,7 +63,7 @@ class TestCustomOptimizers(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -107,7 +107,7 @@ class TestCustomOptimizers(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -143,5 +143,5 @@ class TestCustomOptimizers(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py index dd0af32f3..13244a215 100644 --- a/tests/e2e/test_packing_loss.py +++ b/tests/e2e/test_packing_loss.py @@ -8,8 +8,8 @@ import unittest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -63,7 +63,7 @@ class TestPackedLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_tensorboard( temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index 7a08d0c6f..54f564d0e 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from 
axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -65,7 +65,7 @@ class TestPhi(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -114,5 +114,5 @@ class TestPhi(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_relora_llama.py b/tests/e2e/test_relora_llama.py index fef6a3d30..6c785dc86 100644 --- a/tests/e2e/test_relora_llama.py +++ b/tests/e2e/test_relora_llama.py @@ -7,8 +7,8 @@ import os import unittest from pathlib import Path -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -77,7 +77,7 @@ class TestReLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg) assert ( Path(temp_dir) / "checkpoint-100/relora/model.safetensors" diff --git a/tests/e2e/test_reward_model_llama.py b/tests/e2e/test_reward_model_llama.py index c4cb705ea..4cd8602f3 100644 --- a/tests/e2e/test_reward_model_llama.py +++ b/tests/e2e/test_reward_model_llama.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -69,5 +69,5 @@ class TestRewardModelLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) From 19cd83d408ba0d46f2cf6e285488001eeaf4d1c1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 14 Jan 2025 22:07:55 -0500 Subject: [PATCH 05/16] rename references to dpo dataset prep to pref data (#2258) --- src/axolotl/common/datasets.py | 10 +++++----- src/axolotl/utils/data/__init__.py | 2 +- src/axolotl/utils/data/rl.py | 2 +- tests/test_datasets.py | 6 +++--- tests/test_exact_deduplication.py | 6 +++--- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index d07add29b..c693c26d8 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -11,7 +11,7 @@ from datasets import Dataset import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from 
axolotl.utils.data import prepare_dataset -from axolotl.utils.data.rl import load_prepare_dpo_datasets +from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.tokenization import check_dataset_labels @@ -103,9 +103,9 @@ def load_preference_datasets( cli_args: Union[PreprocessCliArgs, TrainerCliArgs], ) -> TrainDatasetMeta: """ - Loads one or more training or evaluation datasets for DPO training, calling - `axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug - information. + Loads one or more training or evaluation datasets for RL training using paired + preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`. + Optionally, logs out debug information. Args: cfg: Dictionary mapping `axolotl` config keys to values. @@ -115,7 +115,7 @@ def load_preference_datasets( Dataclass with fields for training and evaluation datasets and the computed `total_num_steps`. """ - train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) + train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py index 140d02106..7f90bf3cb 100644 --- a/src/axolotl/utils/data/__init__.py +++ b/src/axolotl/utils/data/__init__.py @@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401 encode_pretraining, wrap_pretraining_dataset, ) -from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401 +from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401 from axolotl.utils.data.sft import ( # noqa: F401 get_dataset_wrapper, load_prepare_datasets, diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index edb72f186..9f5c726ab 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -115,7 +115,7 @@ def drop_long_rl_seq( raise ValueError("Unknown RL type") -def load_prepare_dpo_datasets(cfg): +def load_prepare_preference_datasets(cfg): def load_split(dataset_cfgs, _cfg): split_datasets: List[Any] = [] for i, ds_cfg in enumerate(dataset_cfgs): diff --git a/tests/test_datasets.py b/tests/test_datasets.py index b1ecfd6d5..49554d370 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download from transformers import AutoTokenizer from axolotl.utils.data import load_tokenized_prepared_datasets -from axolotl.utils.data.rl import load_prepare_dpo_datasets +from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.dict import DictDefault @@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase): } ) - train_dataset, _ = load_prepare_dpo_datasets(cfg) + train_dataset, _ = load_prepare_preference_datasets(cfg) assert len(train_dataset) == 1800 assert "conversation" in train_dataset.features @@ -329,7 +329,7 @@ class TestDatasetPreparation(unittest.TestCase): } ) - train_dataset, _ = load_prepare_dpo_datasets(cfg) + train_dataset, _ = load_prepare_preference_datasets(cfg) assert len(train_dataset) == 1800 assert "conversation" in train_dataset.features diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py index 2ac6415be..bc0734ed3 100644 --- a/tests/test_exact_deduplication.py +++ b/tests/test_exact_deduplication.py @@ -12,7 +12,7 @@ from datasets 
import Dataset from transformers import AutoTokenizer from axolotl.utils.data import prepare_dataset -from axolotl.utils.data.rl import load_prepare_dpo_datasets +from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_processor, load_tokenizer @@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase): """Verify that loading with deduplication removes duplicates.""" # Load the dataset using the deduplication setting - train_dataset, _ = load_prepare_dpo_datasets(self.cfg) + train_dataset, _ = load_prepare_preference_datasets(self.cfg) # Verify that the dataset has been deduplicated assert len(train_dataset) == 1800, "Dataset was not properly deduplicated" @@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase): """Verify that loading without deduplication retains duplicates.""" self.cfg.dataset_exact_deduplication = False # Load the dataset without deduplication - train_dataset, _ = load_prepare_dpo_datasets(self.cfg) + train_dataset, _ = load_prepare_preference_datasets(self.cfg) # Verify that the dataset retains duplicates assert ( From cba5a457d9541a1ffde6a99977bff575c4899966 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 15 Jan 2025 10:08:56 +0700 Subject: [PATCH 06/16] fix: use text_column even when not packing for pretraining (#2254) * fix: use text_column even when not packing for pretraining * feat: update test to check when not packing * chore: lint * Update src/axolotl/utils/data/pretraining.py Co-authored-by: Wing Lian --------- Co-authored-by: Wing Lian Co-authored-by: Wing Lian --- src/axolotl/utils/data/pretraining.py | 14 +++++++++++--- tests/e2e/test_llama_pretrain.py | 16 ++++++++++------ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index f493db70e..369d2d6fe 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl") def encode_pretraining( - tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List] + tokenizer: PreTrainedTokenizerBase, + max_tokens: int, + examples: Dict[str, List], + text_column: str = "text", ) -> Dict[str, List]: res = tokenizer( - examples["text"], + examples[text_column], truncation=True, max_length=max_tokens - 2, add_special_tokens=True, @@ -196,7 +199,12 @@ def wrap_pretraining_dataset( # set this to 1 so downstream data_loader doesn't try to increase the batch again cfg.micro_batch_size = 1 else: - encode = functools.partial(encode_pretraining, tokenizer, max_tokens) + encode = functools.partial( + encode_pretraining, + tokenizer, + max_tokens, + text_column=cfg.pretraining_dataset[0].text_column or "text", + ) if cfg.shuffle_merged_datasets: dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index 68cd490be..117eba25d 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -4,7 +4,8 @@ E2E tests for llama pretrain import logging import os -import unittest + +import pytest from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets @@ -12,19 +13,22 @@ from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import 
check_model_output_exists, with_temp_dir +from .utils import check_model_output_exists LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" -class TestPretrainLlama(unittest.TestCase): +class TestPretrainLlama: """ Test case for Llama models w pretraining """ - @with_temp_dir - def test_pretrain_w_sample_packing(self, temp_dir): + @pytest.mark.parametrize( + "sample_packing", + [True, False], + ) + def test_pretrain(self, temp_dir, sample_packing): # pylint: disable=duplicate-code cfg = DictDefault( { @@ -32,7 +36,7 @@ class TestPretrainLlama(unittest.TestCase): "tokenizer_type": "LlamaTokenizer", "flash_attention": True, "sequence_len": 1024, - "sample_packing": True, + "sample_packing": sample_packing, "special_tokens": { "unk_token": "", "bos_token": "", From 860609392184cf62a7e0ca676658b170e059ce6c Mon Sep 17 00:00:00 2001 From: jwongTensora Date: Wed, 15 Jan 2025 03:09:29 +0000 Subject: [PATCH 07/16] fix for indexing error from token/embeddings mismatch (#2257) Co-authored-by: jwong --- src/axolotl/utils/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 523fd76fe..4a665c111 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1057,7 +1057,7 @@ class ModelLoader: ) if ( hasattr(self.model, "get_input_embeddings") - and self.model.get_input_embeddings().num_embeddings < embeddings_len + and self.model.get_input_embeddings().num_embeddings != embeddings_len ): resize_kwargs = {} if self.cfg.mean_resizing_embeddings is not None: From af727eedf75518bc603545b03a54a28fa99beeec Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 20 Jan 2025 14:07:34 -0500 Subject: [PATCH 08/16] option to not concatenate during pretraining (#2263) * option to not concatenate during pretraining * simplify conditional and add doc to config.qmd --- docs/config.qmd | 2 ++ src/axolotl/core/trainer_builder.py | 2 ++ src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 6 ++++++ src/axolotl/utils/data/pretraining.py | 9 +++++++++ 4 files changed, 19 insertions(+) diff --git a/docs/config.qmd b/docs/config.qmd index 70679791e..179ee9ed1 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -244,6 +244,8 @@ total_num_tokens: sample_packing_group_size: 100000 # The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples. 
 sample_packing_bin_size: 200
+# whether to concatenate samples during pretraining
+pretraining_sample_concatenation:
 # Use batch flattening for speedups when not using sample_packing
 batch_flattening:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 176ce4174..6f1bae1ef 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1877,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):
         if training_args.pretraining:
+            if self.cfg.pretraining_sample_concatenation is False:
+                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None
 
         if self.cfg.model_config_type == "mamba":
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 4f368994a..98cdee009 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -706,6 +706,12 @@ class AxolotlInputConfig(
     pad_to_sequence_len: Optional[bool] = None
     curriculum_sampling: Optional[bool] = None
     multipack_real_batches: Optional[bool] = None
+    pretraining_sample_concatenation: Optional[bool] = Field(
+        default=None,
+        json_schema_extra={
+            "description": "whether to soft pack/concatenate samples during pretraining",
+        },
+    )
 
     batch_flattening: Optional[Union[Literal["auto"], bool]] = None
 
diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index 369d2d6fe..c30d62575 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -22,6 +22,7 @@ def encode_pretraining(
     max_tokens: int,
     examples: Dict[str, List],
     text_column: str = "text",
+    concatenate: bool = True,
 ) -> Dict[str, List]:
     res = tokenizer(
         examples[text_column],
@@ -33,6 +34,13 @@ def encode_pretraining(
     input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
     targets = [torch.tensor(seq) for seq in res["input_ids"]]
     attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+    if not concatenate:
+        return {
+            "input_ids": [seq.tolist() for seq in input_ids],
+            "labels": [seq.tolist() for seq in targets],
+            "attention_mask": [seq.tolist() for seq in attention_mask],
+        }
+
     new_input_ids = []
     new_labels = []
     new_attention_mask = []
@@ -204,6 +212,7 @@ def wrap_pretraining_dataset(
             tokenizer,
             max_tokens,
             text_column=cfg.pretraining_dataset[0].text_column or "text",
+            concatenate=cfg.pretraining_sample_concatenation is True,
         )
 
     if cfg.shuffle_merged_datasets:
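A rough sketch of the two paths the new `concatenate` flag selects between (the tokenizer below is a toy stand-in, not axolotl code, so the snippet stays self-contained):

```python
# Toy sketch of the behavior toggled by `concatenate` above.
from typing import Dict, List


def tokenize(text: str) -> List[int]:
    # pretend token ids: one id per whitespace-separated word
    return [hash(word) % 32000 for word in text.split()]


def encode(examples: Dict[str, List[str]], max_tokens: int, concatenate: bool) -> List[List[int]]:
    seqs = [tokenize(text)[:max_tokens] for text in examples["text"]]
    if not concatenate:
        # one row per sample; DataCollatorForSeq2Seq pads the ragged batch later
        return seqs
    # default path: pack samples back-to-back into fixed-length rows
    flat = [tok for seq in seqs for tok in seq]
    return [flat[i : i + max_tokens] for i in range(0, len(flat), max_tokens)]


examples = {"text": ["a b c", "d e", "f g h i"]}
print(encode(examples, max_tokens=4, concatenate=True))   # 9 tokens -> 2 full rows + remainder
print(encode(examples, max_tokens=4, concatenate=False))  # 3 ragged rows, one per sample
```

With `concatenate=True` (the default), tokenized samples are packed back-to-back into fixed-length rows; with `pretraining_sample_concatenation: false`, each sample stays its own row and the `DataCollatorForSeq2Seq` returned above pads the ragged batch.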
From bb9d4102c4d11d3129d88b8b563c2d03c4b1f985 Mon Sep 17 00:00:00 2001
From: Adithya Kamath
Date: Wed, 22 Jan 2025 02:09:17 +0530
Subject: [PATCH 09/16] Add 5000 line history limit to tmux for docker cloud (#2268)

---
 docker/Dockerfile-cloud | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/docker/Dockerfile-cloud b/docker/Dockerfile-cloud
index c8249cb79..735afa4dd 100644
--- a/docker/Dockerfile-cloud
+++ b/docker/Dockerfile-cloud
@@ -20,7 +20,8 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
     printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
     printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
     chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
-    chmod +x /root/cloud-entrypoint.sh
+    chmod +x /root/cloud-entrypoint.sh && \
+    echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
 
 ENTRYPOINT ["/root/cloud-entrypoint.sh"]
 CMD ["sleep", "infinity"]
From 8fb72cbc0b94129141bae5fa4d84edd23b648af6 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 21 Jan 2025 15:39:30 -0500
Subject: [PATCH 10/16] use the extracted field_messages to parse the role fields (#2265)

---
 scripts/chat_datasets.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/chat_datasets.py b/scripts/chat_datasets.py
index 5eb5bde1e..6210b1138 100644
--- a/scripts/chat_datasets.py
+++ b/scripts/chat_datasets.py
@@ -30,7 +30,7 @@ def parse_dataset(dataset=None, split="train"):
         )
         ds_cfg["field_messages"] = field_messages
 
-    message_fields = features["conversations"][0].keys()
+    message_fields = features[field_messages][0].keys()
     message_field_role = None
     for key in ["from", "role"]:
        if key in message_fields:
From 8a7a0b07dc5ce6da9171e28a0818b447b6d7cea2 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 23 Jan 2025 21:17:57 -0500
Subject: [PATCH 11/16] support for latest transformers release 4.48.1 (#2256)

---
 cicd/cicd.sh | 3 +-
 requirements.txt | 4 +-
 src/axolotl/core/trainer_builder.py | 1 +
 src/axolotl/monkeypatch/trainer_grad_accum.py | 308 ------------------
 .../monkeypatch/transformers_fa_utils.py | 67 ++++
 src/axolotl/utils/models.py | 18 +-
 tests/e2e/multigpu/test_llama.py | 16 +-
 tests/e2e/patched/test_mixtral_samplepack.py | 6 +-
 tests/e2e/patched/test_model_patches.py | 7 +-
 tests/e2e/patched/test_unsloth_integration.py | 4 +-
 tests/e2e/solo/__init__.py | 0
 tests/e2e/{ => solo}/test_relora_llama.py | 2 +-
 tests/patched/test_llama_trainer_ga.py | 25 --
 13 files changed, 98 insertions(+), 363 deletions(-)
 delete mode 100644 src/axolotl/monkeypatch/trainer_grad_accum.py
 create mode 100644 src/axolotl/monkeypatch/transformers_fa_utils.py
 create mode 100644 tests/e2e/solo/__init__.py
 rename tests/e2e/{ => solo}/test_relora_llama.py (97%)
 delete mode 100644 tests/patched/test_llama_trainer_ga.py

diff --git a/cicd/cicd.sh b/cicd/cicd.sh
index 91926127f..34a30db44 100755
--- a/cicd/cicd.sh
+++ b/cicd/cicd.sh
@@ -6,5 +6,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
 pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
 # pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
+pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
-pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
+pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
diff --git a/requirements.txt b/requirements.txt
index 1f7ac7bba..52e146411 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,9 +13,9 @@ liger-kernel==0.5.2
 packaging==23.2
 peft==0.14.0
-transformers==4.47.1
+transformers==4.48.1
 tokenizers>=0.21.0
-accelerate==1.2.1
+accelerate==1.3.0
 datasets==3.2.0
 deepspeed==0.16.1
 trl==0.13.0
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index
6f1bae1ef..edc842994 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1079,6 +1079,7 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): super().__init__(*args, **kwargs) self.dataset_tags = dataset_tags self.optimizer = None + self.model_accepts_loss_kwargs = False def create_optimizer(self): if self.args.loraplus_lr_ratio is None: diff --git a/src/axolotl/monkeypatch/trainer_grad_accum.py b/src/axolotl/monkeypatch/trainer_grad_accum.py deleted file mode 100644 index 05d706704..000000000 --- a/src/axolotl/monkeypatch/trainer_grad_accum.py +++ /dev/null @@ -1,308 +0,0 @@ -""" -fix for FSDP gradient accumulation -see https://github.com/huggingface/transformers/pull/35128 -""" -import inspect -import logging - -from transformers import LlamaForCausalLM, Trainer -from transformers.modeling_flash_attention_utils import _flash_attention_forward - -from axolotl.monkeypatch.utils import detab_code - -LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum") - -ORIGINAL_CONTEXT_CODE = """ - with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) -""" - -PATCHED_CONTEXT_CODE = """ - with self.compute_loss_context_manager(): - if self.model_accepts_loss_kwargs: - loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) - else: - loss = self.compute_loss(model, inputs) -""" - -ORIGINAL_LLAMA_FCLM_CODE = """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) -""" - -PATCHED_LLAMA_FCLM_CODE = """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention - num_items_in_batch = kwargs.pop("num_items_in_batch", None) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - hidden_states = 
outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs) -""" - - -def get_training_step_code() -> str: - training_step = inspect.getsource( - Trainer.training_step # pylint: disable=protected-access - ) - return training_step - - -def check_training_step_is_patchable() -> bool: - training_step = get_training_step_code() - training_step, _ = detab_code(training_step) - return ORIGINAL_CONTEXT_CODE in training_step - - -def patch_training_step_for_ga(): - """ - monkeypatch for fixing the training loop for gradient accumulation - """ - - try: - training_step = get_training_step_code() - except OSError: - return - Trainer._original_training_step = training_step # pylint: disable=protected-access - training_step, _ = detab_code(training_step) - if ORIGINAL_CONTEXT_CODE not in training_step: - return - # assert ( - # ORIGINAL_CONTEXT_CODE in training_step - # ), "Original training_step code not found" - - training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE) - training_step = training_step.replace( - "def training_step(", - "def _fixed_training_step(", - 1, - ) - - # load imports necessary - import transformers.trainer - - items_to_import = [] - for item in dir(transformers.trainer): - if item in training_step: - items_to_import.append(item) - - exec( # pylint: disable=exec-used # nosec B102 - "from transformers.trainer import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(training_step, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching training_step") - Trainer.training_step = ( # pylint: disable=protected-access - _fixed_training_step # pylint: disable=undefined-variable # noqa: F821 - ) - - -def get_model_forward_code() -> str: - forward = inspect.getsource( - LlamaForCausalLM.forward # pylint: disable=protected-access - ) - return forward - - -def check_forward_is_patchable() -> bool: - forward = get_model_forward_code() - forward, _ = detab_code(forward) - return ORIGINAL_LLAMA_FCLM_CODE in forward - - -def patch_forward_for_ga(): - """ - monkeypatch for fixing the training loop for gradient accumulation - """ - - try: - forward = get_model_forward_code() - except OSError: - return - LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access - forward, _ = detab_code(forward) - if ORIGINAL_LLAMA_FCLM_CODE not in forward: - return - # assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found" - - forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE) - forward = forward.replace( - "def forward(", - "def _fixed_forward(", - 1, - ) - - # load imports necessary - import transformers.models.llama.modeling_llama - - items_to_import = [] - for item in dir(transformers.models.llama.modeling_llama): - if item in forward: - items_to_import.append(item) - - exec( # pylint: disable=exec-used # nosec B102 - "from transformers.models.llama.modeling_llama import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(forward, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching forward") - LlamaForCausalLM.forward = ( # pylint: disable=protected-access - _fixed_forward # pylint: disable=undefined-variable # noqa: F821 - ) 
- - -ORIGINAL_TRAINER_CODE = """ - context = ( - functools.partial(self.accelerator.no_sync, model=model) - if i != len(batch_samples) - 1 - else contextlib.nullcontext - ) - with context(): - tr_loss_step = self.training_step(model, inputs, num_items_in_batch) -""" - -PATCHED_TRAINER_CODE = """ - disable_deepspeed_no_sync = ( - self.accelerator.distributed_type == DistributedType.DEEPSPEED - # and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients() - ) - context = ( - functools.partial(self.accelerator.no_sync, model=model) - if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync - else contextlib.nullcontext - ) - with context(): - tr_loss_step = self.training_step(model, inputs, num_items_in_batch) -""" - - -def get_training_loop_code() -> str: - training_loop = inspect.getsource( - Trainer._inner_training_loop # pylint: disable=protected-access - ) - return training_loop - - -def check_training_loop_is_patchable() -> bool: - training_loop = get_training_loop_code() - training_loop, _ = detab_code(training_loop) - return ORIGINAL_TRAINER_CODE in training_loop - - -def patch_training_loop_for_deepspeed_0_16_x(): - """ - monkeypatch for fixing the training loop for deepspeed GA - - see https://github.com/huggingface/transformers/pull/35157 - """ - - try: - training_loop = get_training_loop_code() - except OSError: - return - Trainer._original_inner_training_loop = ( # pylint: disable=protected-access - training_loop - ) - training_loop, _ = detab_code(training_loop) - if ORIGINAL_TRAINER_CODE not in training_loop: - return - - training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE) - training_loop = training_loop.replace( - "def _inner_training_loop(", - "def _fixed_inner_training_loop(", - 1, - ) - - # load imports necessary - import transformers.trainer - - items_to_import = [] - for item in dir(transformers.trainer): - if item in training_loop: - items_to_import.append(item) - - exec( # pylint: disable=exec-used # nosec B102 - "from transformers.trainer import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching _inner_training_loop for fsdp optimizer save") - Trainer._inner_training_loop = ( # pylint: disable=protected-access - _fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821 - ) - - -def patch_flash_attention_forward(): - """ - monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch - """ - - import transformers.modeling_flash_attention_utils - - def proxy_flash_attention_forward(*args, **kwargs): - kwargs.pop("num_items_in_batch", None) - - return _flash_attention_forward(*args, **kwargs) - - transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access - proxy_flash_attention_forward - ) - transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access - proxy_flash_attention_forward - ) diff --git a/src/axolotl/monkeypatch/transformers_fa_utils.py b/src/axolotl/monkeypatch/transformers_fa_utils.py new file mode 100644 index 000000000..f34ecb8c0 --- /dev/null +++ b/src/axolotl/monkeypatch/transformers_fa_utils.py @@ -0,0 +1,67 @@ +""" +see https://github.com/huggingface/transformers/pull/35834 +""" + +import logging +from functools import partial +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + + +def 
fixed_fa_peft_integration_check(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    target_dtype: Optional[torch.dtype] = None,
+    preferred_dtype: Optional[torch.dtype] = None,
+):
+    """
+    PEFT usually casts the layer norms in float32 for training stability reasons,
+    therefore the input hidden states get silently cast to float32. Hence, we need
+    to cast them back to float16 / bfloat16 just to be sure everything works as expected.
+    This might slow down training & inference, so it is recommended to not cast the LayerNorms!
+
+    Args:
+        query (`torch.Tensor`):
+            Input query states to be passed to Flash Attention API
+        key (`torch.Tensor`):
+            Input key states to be passed to Flash Attention API
+        value (`torch.Tensor`):
+            Input value states to be passed to Flash Attention API
+        target_dtype (`torch.dtype`, *optional*):
+            The dtype to convert the attention tensors to. Conversion can be ignored by
+            not providing the target dtype.
+        preferred_dtype (`torch.dtype`, *optional*):
+            The preferred dtype to convert the attention tensors to regardless of the
+            target dtype.
+    """
+    if target_dtype is None and preferred_dtype is None:
+        return query, key, value
+
+    if preferred_dtype and target_dtype != preferred_dtype:
+        target_dtype = preferred_dtype
+
+    # check if any of query, key, or value are in float32. If so, cast them back to target dtype.
+    if any(module.dtype == torch.float32 for module in [query, key, value]):
+        logger.warning_once(
+            f"The input hidden states seem to have been silently cast to float32; this might be related to"
+            f" upcasted embedding or layer norm layers in float32. We will cast the input back to"
+            f" {target_dtype}."
+        )
+
+        query = query.to(target_dtype)
+        key = key.to(target_dtype)
+        value = value.to(target_dtype)
+
+    return query, key, value
+
+
+def patch_fa_peft_integration():
+    import transformers.modeling_flash_attention_utils
+
+    transformers.modeling_flash_attention_utils.fa_peft_integration_check = partial(
+        fixed_fa_peft_integration_check, preferred_dtype=None
+    )
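In isolation, the cast-back behavior that check applies looks like the following (a minimal sketch in plain torch; shapes are arbitrary and no flash-attention call is involved):

```python
import torch

# fp32 q/k/v, as produced when PEFT keeps layer norms / embeddings in float32
query = torch.randn(2, 8, 16, dtype=torch.float32)
key = torch.randn(2, 8, 16, dtype=torch.float32)
value = torch.randn(2, 8, 16, dtype=torch.float32)

target_dtype = torch.bfloat16  # what the flash attention kernel actually wants

# same check as fixed_fa_peft_integration_check: only cast when fp32 sneaks in
if any(t.dtype == torch.float32 for t in (query, key, value)):
    query, key, value = (t.to(target_dtype) for t in (query, key, value))

assert query.dtype == key.dtype == value.dtype == torch.bfloat16
```

Binding `preferred_dtype=None` in the `partial` keeps that default behavior while routing every call through the fixed signature.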
"use_tensorboard": True, + "bf16": True, } ) @@ -201,6 +203,7 @@ class TestMultiGPULlama: "lr_scheduler": "cosine", "flash_attention": True, "use_tensorboard": True, + "bf16": True, } ) @@ -223,8 +226,12 @@ class TestMultiGPULlama: ] ) + loss_threshold = 2.3 check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", + "train/train_loss", + loss_threshold, + "Train Loss is too high", ) def test_dpo_qlora_ddp(self, temp_dir): @@ -275,6 +282,7 @@ class TestMultiGPULlama: "lr_scheduler": "cosine", "flash_attention": True, "use_tensorboard": True, + "bf16": True, } ) @@ -297,8 +305,12 @@ class TestMultiGPULlama: ] ) + loss_threshold = 2.3 check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", + "train/train_loss", + loss_threshold, + "Train Loss is too high", ) @pytest.mark.parametrize( diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index 156dac7e8..8746c923b 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -102,9 +102,5 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, dataset_meta=dataset_meta) - assert ( - "MixtralFlashAttention2" - in model.model.layers[0].self_attn.__class__.__name__ - ) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py index 78b01be64..c6a13af19 100644 --- a/tests/e2e/patched/test_model_patches.py +++ b/tests/e2e/patched/test_model_patches.py @@ -49,12 +49,7 @@ class TestModelPatches(unittest.TestCase): ) normalize_config(cfg) tokenizer = load_tokenizer(cfg) - model, _ = load_model(cfg, tokenizer, inference=False) - - assert ( - "MixtralFlashAttention2" - in model.model.layers[0].self_attn.__class__.__name__ - ) + load_model(cfg, tokenizer, inference=False) @with_temp_dir def test_mistral_multipack(self, temp_dir): diff --git a/tests/e2e/patched/test_unsloth_integration.py b/tests/e2e/patched/test_unsloth_integration.py index bc6476dab..403d26147 100644 --- a/tests/e2e/patched/test_unsloth_integration.py +++ b/tests/e2e/patched/test_unsloth_integration.py @@ -3,8 +3,6 @@ import unittest import pytest -from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable - @pytest.mark.skip( reason="Unsloth integration will be broken going into latest transformers" @@ -13,6 +11,8 @@ class TestUnslothIntegration(unittest.TestCase): """Unsloth monkeypatch integration tests.""" def test_is_self_attn_patchable(self): + from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable + # ensures the current version of transformers has loss code that matches our patching code self.assertTrue( check_self_attn_is_patchable(), diff --git a/tests/e2e/solo/__init__.py b/tests/e2e/solo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/e2e/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py similarity index 97% rename from tests/e2e/test_relora_llama.py rename to tests/e2e/solo/test_relora_llama.py index 6c785dc86..191f76f64 100644 --- a/tests/e2e/test_relora_llama.py +++ b/tests/e2e/solo/test_relora_llama.py @@ -13,7 +13,7 @@ from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import 
check_model_output_exists, check_tensorboard, with_temp_dir +from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/patched/test_llama_trainer_ga.py b/tests/patched/test_llama_trainer_ga.py deleted file mode 100644 index 58c229cf3..000000000 --- a/tests/patched/test_llama_trainer_ga.py +++ /dev/null @@ -1,25 +0,0 @@ -""""Test module for checking whether the Hugging Face Transformers is working as expected.""" -import unittest - -from axolotl.monkeypatch.trainer_grad_accum import ( - check_forward_is_patchable, - check_training_step_is_patchable, -) - - -class TestTrainerGAIntegration(unittest.TestCase): - """llama monkeypatch integration tests.""" - - def test_train_step_patchable(self): - # ensures the current version of transformers has loss code that matches our patching code - self.assertTrue( - check_training_step_is_patchable(), - "HF transformers Trainer.training_step has changed and isn't patchable", - ) - - def test_model_forward_patchable(self): - # ensures the current version of transformers has loss code that matches our patching code - self.assertTrue( - check_forward_is_patchable(), - "HF transformers LlamaForCausalLM.forward has changed and isn't patchable", - ) From 74f9782fc38884b53c444594da87fdb182a139df Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 24 Jan 2025 22:05:58 +0700 Subject: [PATCH 12/16] chore(doc): fix explanation on gcs creds retrieval (#2272) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6afcaf146..ff77b982e 100644 --- a/README.md +++ b/README.md @@ -519,8 +519,8 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod train_on_split: validation # loading from s3 or gcs - # s3 creds will be loaded from the system default and gcs only supports public access - - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs. + # s3 creds will be loaded from the system default / gcs will attempt to load from gcloud creds, google metadata service, or anon + - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above ... 
 # Loading Data From a Public URL
From b2774af66c64fe07e50d648c08b1446629f0da85 Mon Sep 17 00:00:00 2001
From: mashdragon <122402293+mashdragon@users.noreply.github.com>
Date: Fri, 24 Jan 2025 15:06:50 +0000
Subject: [PATCH 13/16] Take `split` param from config in all load_dataset instances (#2281)

---
 src/axolotl/utils/data/shared.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py
index d14496d96..e4f31a184 100644
--- a/src/axolotl/utils/data/shared.py
+++ b/src/axolotl/utils/data/shared.py
@@ -107,6 +107,13 @@ def load_dataset_w_config(config_dataset, auth_token):
         except (FileNotFoundError, ConnectionError):
             pass
 
+    # gather extra args from the config
+    load_ds_kwargs = {}
+    if config_dataset.split:
+        load_ds_kwargs["split"] = config_dataset.split
+    else:
+        load_ds_kwargs["split"] = None
+
     # prefer local dataset, even if hub exists
     local_path = Path(config_dataset.path)
     if local_path.exists():
@@ -118,7 +125,7 @@ def load_dataset_w_config(config_dataset, auth_token):
                 name=config_dataset.name,
                 data_files=config_dataset.data_files,
                 streaming=False,
-                split=None,
+                **load_ds_kwargs,
             )
         else:
             try:
@@ -130,7 +137,7 @@
                     config_dataset.path,
                     name=config_dataset.name,
                     streaming=False,
-                    split=None,
+                    **load_ds_kwargs,
                 )
         elif local_path.is_file():
             ds_type = get_ds_type(config_dataset)
@@ -140,16 +147,13 @@
                 name=config_dataset.name,
                 data_files=config_dataset.path,
                 streaming=False,
-                split=None,
+                **load_ds_kwargs,
             )
         else:
             raise ValueError(
                 "unhandled dataset load: local path exists, but is neither a directory or a file"
             )
     elif ds_from_hub:
-        load_ds_kwargs = {}
-        if config_dataset.split:
-            load_ds_kwargs["split"] = config_dataset.split
         ds = load_dataset(
             config_dataset.path,
             name=config_dataset.name,
@@ -173,9 +177,9 @@
             name=config_dataset.name,
             data_files=config_dataset.path,
             streaming=False,
-            split=None,
             storage_options=storage_options,
             trust_remote_code=config_dataset.trust_remote_code,
+            **load_ds_kwargs,
         )
     elif config_dataset.path.startswith("https://"):
         ds_type = get_ds_type(config_dataset)
@@ -184,9 +188,9 @@
             name=config_dataset.name,
             data_files=config_dataset.path,
             streaming=False,
-            split=None,
             storage_options=storage_options,
             trust_remote_code=config_dataset.trust_remote_code,
+            **load_ds_kwargs,
         )
     else:
         if isinstance(config_dataset.data_files, str):
@@ -214,7 +218,7 @@
             name=config_dataset.name,
             data_files=fp,
             streaming=False,
-            split=None,
+            **load_ds_kwargs,
         )
     if not ds:
         raise ValueError("unhandled dataset load")
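A rough sketch of what that plumbing amounts to at the call site (the dataset path below is a placeholder, not something this patch references):

```python
# illustrative only: how a dataset config's `split` now reaches load_dataset
from datasets import load_dataset

config_split = "train"  # e.g. `split: train` on the dataset entry in the config

load_ds_kwargs = {"split": config_split if config_split else None}

# any hub dataset with a matching split works here; this path is a placeholder
ds = load_dataset("mhenrichsen/alpaca_2k_test", **load_ds_kwargs)
print(ds)
```

When `split` is unset, `split=None` is still passed and `load_dataset` returns the full `DatasetDict`, which matches the previously hard-coded `split=None` behavior in the local and remote branches.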
From 60861624881ab1e70579a100a8138fcda9aef0fb Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Fri, 24 Jan 2025 22:07:02 +0700
Subject: [PATCH 14/16] chore(doc): improve explanation for *_steps and *_strategy (#2270)

---
 docs/config.qmd | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/docs/config.qmd b/docs/config.qmd
index 179ee9ed1..f253decbe 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -360,10 +360,11 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
 learning_rate: 0.00003
 lr_quadratic_warmup:
 logging_steps:
-eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
+eval_steps: # Leave empty to eval at each epoch, integer for every N steps, float for fraction of total steps
 evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
-save_strategy: # Set to `"no"` to skip checkpoint saves
-save_steps: # Leave empty to save at each epoch
+eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
+save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when a better result is achieved, leave empty to infer from `save_steps`.
+save_steps: # Leave empty to save at each epoch, integer for every N steps, float for fraction of total steps
 saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
 save_total_limit: # Checkpoints saved at a time
 # Maximum number of iterations to train for. It precedes num_epochs which means that
From 20620771f1002b55438eeeb941ca6bb76216b8da Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 24 Jan 2025 12:55:20 -0500
Subject: [PATCH 15/16] Pretrain multipack (#2278)

* fix for pretrain with packing
* fix model name and loss expected
* make sure to check with micro batch size for pretraining
* change loss thresholds based on parametrization
* make tests smaller for CI
* fix pretrain packing
* fix pretrain packing test
* address pr feedback

---
 src/axolotl/core/trainer_builder.py | 2 ++
 src/axolotl/utils/data/pretraining.py | 13 +++++------
 src/axolotl/utils/trainer.py | 7 ++++--
 tests/e2e/test_llama_pretrain.py | 32 ++++++++++++++++++++-------
 tests/test_packed_pretraining.py | 9 +++++---
 5 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index edc842994..62c6a9721 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1880,6 +1880,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if training_args.pretraining:
             if self.cfg.pretraining_sample_concatenation is False:
                 return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
+            if self.cfg.micro_batch_size > 1:
+                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None
 
         if self.cfg.model_config_type == "mamba":
diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index c30d62575..f20ced221 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -191,7 +191,7 @@ def wrap_pretraining_dataset(
             tokenizer,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=max_tokens * batch_size,
+            pad_to_multiple_of=max_tokens,
             multipack_attn=cfg.pretrain_multipack_attn,
         )
         encode = functools.partial(
@@ -201,8 +201,6 @@ def wrap_pretraining_dataset(
             max_seq_length=max_tokens,
             batch_size=batch_size,
             multipack_attn=cfg.pretrain_multipack_attn,
-            group_size=cfg.sample_packing_group_size,
-            bin_size=cfg.sample_packing_bin_size,
         )
         # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1
@@ -247,9 +245,7 @@ def encode_packed_pretraining(
     examples: Dict[str, List],
     max_seq_length: int = 2048,
     batch_size: int = 4,
-    multipack_attn: Optional[bool] = False,
-    group_size: int = 100000,
-    bin_size: int = 200,
+    multipack_attn: Optional[bool] = True,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples
@@ -260,6 +256,9 @@
         train_dataset,
         max_seq_length,
skip_position_ids=not multipack_attn, + # FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm + # workaround by using the position id logic for now in trainer + drop_attention_mask=multipack_attn, ) sampler = MultipackBatchSampler( @@ -267,8 +266,6 @@ def encode_packed_pretraining( lengths=get_dataset_lengths(train_dataset), batch_size=1, batch_max_len=batch_size * max_seq_length, - group_size=group_size, - bin_size=bin_size, drop_last=True, ) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 34b505ff1..bfd21703d 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -310,19 +310,22 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): def process_pretraining_datasets_for_packing( - train_dataset, sequence_len, skip_position_ids=True + train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False ): drop_long = partial(drop_long_seq, sequence_len=sequence_len) train_dataset = train_dataset.filter( drop_long, desc="Dropping Long Sequences", + load_from_cache_file=False, ) - if skip_position_ids: + if not skip_position_ids: train_dataset = train_dataset.map( add_position_ids, desc="Add position_id column (Pretraining Sample Packing)", ) + if drop_attention_mask: + train_dataset = train_dataset.remove_columns("attention_mask") return train_dataset diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index 117eba25d..c1f024b87 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -13,7 +13,7 @@ from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import check_model_output_exists +from .utils import check_model_output_exists, check_tensorboard LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -28,19 +28,25 @@ class TestPretrainLlama: "sample_packing", [True, False], ) - def test_pretrain(self, temp_dir, sample_packing): + @pytest.mark.parametrize( + "pretrain_multipack_attn", + [True, False], + ) + def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn): + if not sample_packing and pretrain_multipack_attn: + return + # pylint: disable=duplicate-code cfg = DictDefault( { - "base_model": "JackFram/llama-68m", - "tokenizer_type": "LlamaTokenizer", + "base_model": "HuggingFaceTB/SmolLM2-135M", "flash_attention": True, "sequence_len": 1024, "sample_packing": sample_packing, + "pretrain_multipack_attn": pretrain_multipack_attn, + "dataset_processes": 1, "special_tokens": { - "unk_token": "", - "bos_token": "", - "eos_token": "", + "pad_token": "<|endoftext|>", }, "pretraining_dataset": [ { @@ -51,7 +57,7 @@ class TestPretrainLlama: ], "max_steps": 5, "num_epochs": 1, - "micro_batch_size": 1, + "micro_batch_size": 2, "gradient_accumulation_steps": 1, "val_set_size": 0.0, "output_dir": temp_dir, @@ -60,6 +66,7 @@ class TestPretrainLlama: "lr_scheduler": "cosine", "save_safetensors": True, "bf16": "auto", + "use_tensorboard": True, } ) normalize_config(cfg) @@ -68,3 +75,12 @@ class TestPretrainLlama: train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) + loss_threshold = 3.5 + if sample_packing and not pretrain_multipack_attn: + loss_threshold = 6.5 + check_tensorboard( + temp_dir + "/runs", + "train/train_loss", + loss_threshold, + "Train Loss is too high", + ) diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py index 
fbb776aa5..9f9ae60fb 100644
--- a/tests/test_packed_pretraining.py
+++ b/tests/test_packed_pretraining.py
@@ -41,6 +41,7 @@ class TestPretrainingPacking(unittest.TestCase):
                     }
                 ],
                 "sample_packing": True,
+                "pretrain_multipack_attn": True,
                 "pad_to_sequence_len": True,
                 "sequence_len": 2048,
                 "micro_batch_size": 2,
@@ -87,9 +88,11 @@ class TestPretrainingPacking(unittest.TestCase):
                 assert data["labels"].shape == torch.Size(
                     [1, original_bsz * cfg.sequence_len]
                 )
-                assert data["attention_mask"].shape == torch.Size(
-                    [1, original_bsz * cfg.sequence_len]
-                )
+                assert "attention_mask" not in data
+                # FIXME add back once we fix packing unpad/pad with attention mask
+                # assert data["attention_mask"].shape == torch.Size(
+                #     [1, original_bsz * cfg.sequence_len]
+                # )
                 idx += 1
From 887513285d98132142bf5db2a74eb5e0928787f1 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 24 Jan 2025 12:56:28 -0500
Subject: [PATCH 16/16] support for custom lr groups for non-embedding modules (#2213)

* support for custom lr groups for non-embedding modules
invert name check for group modules
include lr_groups in training args
additional conditional for creating optimizer
fix regular params as w weight decay
fix lookup and add docs
* address pr feedback

---
 docs/lr_groups.qmd | 29 ++++
 src/axolotl/core/trainer_builder.py | 142 ++++++++++++------
 .../config/models/input/v0_4_1/__init__.py | 9 ++
 3 files changed, 131 insertions(+), 49 deletions(-)
 create mode 100644 docs/lr_groups.qmd

diff --git a/docs/lr_groups.qmd b/docs/lr_groups.qmd
new file mode 100644
index 000000000..52059016c
--- /dev/null
+++ b/docs/lr_groups.qmd
@@ -0,0 +1,29 @@
+---
+title: Learning Rate Groups
+description: "Setting different learning rates by module name"
+---
+
+## Background
+
+Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of
+modules in a model.
+
+## Example
+
+```yaml
+lr_groups:
+  - name: o_proj
+    modules:
+      - self_attn.o_proj.weight
+    lr: 1e-6
+  - name: q_proj
+    modules:
+      - model.layers.2.self_attn.q_proj.weight
+    lr: 1e-5
+
+learning_rate: 2e-5
+```
+
+In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
+of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
+self attention `q_proj` module.
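Mechanically, each of these groups ends up as an ordinary torch optimizer parameter group. A minimal sketch of the same idea in plain `torch.optim`, with a two-layer stand-in model in place of real attention modules:

```python
import torch
from torch import nn

# stand-in model; imagine layer "0" plays the role of every o_proj module
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

default_lr, o_proj_lr = 2e-5, 1e-6
grouped = [p for n, p in model.named_parameters() if n.startswith("0.")]
rest = [p for n, p in model.named_parameters() if not n.startswith("0.")]

optimizer = torch.optim.AdamW(
    [
        {"params": grouped, "lr": o_proj_lr},  # matched modules get their group lr
        {"params": rest},  # everything else falls back to the default lr
    ],
    lr=default_lr,
)
print([group["lr"] for group in optimizer.param_groups])  # [1e-06, 2e-05]
```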
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 62c6a9721..d63a10e74 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -243,6 +243,10 @@ class AxolotlTrainingMixins: default=None, metadata={"help": "Scale the learning rate for the embedding layers."}, ) + lr_groups: Optional[list[dict]] = field( + default=None, + metadata={"help": "Specify learning rate groups for with different LRs."}, + ) embedding_lr: Optional[float] = field( default=None, metadata={"help": "absolute learning rate for the embedding layers."}, @@ -461,11 +465,95 @@ class AxolotlTrainer(SchedulerMixin, Trainer): ) return super()._wrap_model(model, training=training, dataloader=dataloader) + def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs): + decay_parameters = self.get_decay_parameter_names(opt_model) + params = { + "to_weight_decay": {}, # LayerNorm and bias + "embeddings": {}, # lm_head, embed_tokens, + "no_weight_decay": {}, + } + lr_groups_lookup = {} + lr_groups_learning_rates = {} + if self.args.lr_groups: + for lr_group in self.args.lr_groups: + group_name = lr_group["name"] + group_modules = lr_group["modules"] + for module in group_modules: + lr_groups_lookup[module] = group_name + lr_groups_learning_rates[group_name] = lr_group["lr"] + params[f"to_weight_decay_{group_name}"] = {} + + for name, param in opt_model.named_parameters(): + if not param.requires_grad: + continue + if name.endswith("modules_to_save.default.weight") or any( + embed_name in name for embed_name in ["embed_tokens", "lm_head"] + ): + params["embeddings"][name] = param + elif name in decay_parameters: + lr_group_modules = [ + group_modules + for group_modules in lr_groups_lookup + if group_modules in name + ] + if lr_groups_lookup and any(lr_group_modules): + lr_group_module = lr_group_modules[0] + group_name = lr_groups_lookup[lr_group_module] + params[f"to_weight_decay_{group_name}"][name] = param + else: + params["to_weight_decay"][name] = param + else: + params["no_weight_decay"][name] = param + optimizer_grouped_parameters = [] + if params["to_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["to_weight_decay"].values()), + "weight_decay": self.args.weight_decay, + "lr": optimizer_kwargs["lr"], + } + ) + if params["embeddings"]: + lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name + if self.args.embedding_lr_scale: + lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name + elif self.args.embedding_lr: + lr = self.args.embedding_lr # pylint: disable=invalid-name + optimizer_grouped_parameters.append( + { + "params": list(params["embeddings"].values()), + "weight_decay": 0.0, + "lr": lr, + } + ) + if params["no_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["no_weight_decay"].values()), + "weight_decay": 0.0, + "lr": optimizer_kwargs["lr"], + } + ) + for group_name, group_lr in lr_groups_learning_rates.items(): + if params[f"to_weight_decay_{group_name}"]: + optimizer_grouped_parameters.append( + { + "params": list( + params[f"to_weight_decay_{group_name}"].values() + ), + "weight_decay": self.args.weight_decay, + "lr": group_lr, + } + ) + + return optimizer_grouped_parameters + def create_optimizer(self): if ( self.args.loraplus_lr_ratio is None and self.args.embedding_lr_scale is None and self.args.embedding_lr is None + and self.args.lr_groups is None and self.args.alternate_optimizer not in [ "optimi_adamw", @@ -479,59 +567,13 @@ class 
AxolotlTrainer(SchedulerMixin, Trainer): opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model if self.optimizer is None: # pylint: disable=access-member-before-definition - decay_parameters = self.get_decay_parameter_names(opt_model) - params = { - "to_weight_decay": {}, # LayerNorm and bias - "embeddings": {}, # lm_head, embed_tokens, - "no_weight_decay": {}, - } - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( self.args, opt_model, ) - - for name, param in opt_model.named_parameters(): - if not param.requires_grad: - continue - if name.endswith("modules_to_save.default.weight") or any( - embed_name in name for embed_name in ["embed_tokens", "lm_head"] - ): - params["embeddings"][name] = param - elif name in decay_parameters: - params["to_weight_decay"][name] = param - else: - params["no_weight_decay"][name] = param - optimizer_grouped_parameters = [] - if params["to_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["to_weight_decay"].values()), - "weight_decay": self.args.weight_decay, - "lr": optimizer_kwargs["lr"], - } - ) - if params["embeddings"]: - lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name - if self.args.embedding_lr_scale: - lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name - elif self.args.embedding_lr: - lr = self.args.embedding_lr # pylint: disable=invalid-name - optimizer_grouped_parameters.append( - { - "params": list(params["embeddings"].values()), - "weight_decay": 0.0, - "lr": lr, - } - ) - if params["no_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["no_weight_decay"].values()), - "weight_decay": 0.0, - "lr": optimizer_kwargs["lr"], - } - ) + optimizer_grouped_parameters = self.create_optimizer_grouped_parameters( + opt_model, optimizer_kwargs + ) if self.args.loraplus_lr_ratio is not None: loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) @@ -548,6 +590,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer): elif ( self.args.embedding_lr_scale is not None or self.args.embedding_lr is not None + or self.args.lr_groups is not None ): self.optimizer = ( # pylint: disable=attribute-defined-outside-init optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) @@ -1665,6 +1708,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ] = self.cfg.loraplus_lr_embedding training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale + training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]: training_arguments_kwargs["lr_scheduler_type"] = "cosine" diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 98cdee009..44e247886 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -147,6 +147,14 @@ class UserDefinedPrompterType(BaseModel): field: Optional[str] = None +class LrGroup(BaseModel): + """Custom learning rate group configuration""" + + name: str + modules: List[str] + lr: float + + class SFTDataset(BaseModel): """SFT configuration subset""" @@ -475,6 +483,7 @@ class HyperparametersConfig(BaseModel): cosine_min_lr_ratio: Optional[float] = None cosine_constant_lr_ratio: Optional[float] = None lr_div_factor: Optional[float] = None + lr_groups: Optional[List[LrGroup]] = None adam_epsilon: Optional[float] = None 
adam_beta1: Optional[float] = None
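As a closing illustration, entries under `lr_groups` validate against the new `LrGroup` model; the sketch below mirrors that model rather than importing it, so it runs without axolotl installed:

```python
from typing import List

from pydantic import BaseModel


class LrGroup(BaseModel):
    """illustrative mirror of the config model added above"""

    name: str
    modules: List[str]
    lr: float


groups = [
    LrGroup(name="o_proj", modules=["self_attn.o_proj.weight"], lr=1e-6),
    LrGroup(name="q_proj", modules=["model.layers.2.self_attn.q_proj.weight"], lr=1e-5),
]
print([(group.name, group.lr) for group in groups])  # [('o_proj', 1e-06), ('q_proj', 1e-05)]
```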