diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 671be4b65..919cfd654 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -6,7 +6,7 @@ on: - '**.py' - 'requirements.txt' - '.github/workflows/*.yml' - - "*.md" + - "*.[q]md" - "examples/**/*.y[a]?ml" workflow_dispatch: diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index 91cbaf957..ab886c67f 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -1,6 +1,9 @@ name: docker-multigpu-tests-biweekly on: + pull_request: + paths: + - 'tests/e2e/multigpu/*.py' workflow_dispatch: schedule: - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 6b35698cb..30ed397ce 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -25,6 +25,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] + pytorch_version: ["2.3.1", "2.4.0"] timeout-minutes: 20 steps: @@ -37,6 +38,10 @@ jobs: python-version: ${{ matrix.python_version }} cache: 'pip' # caching pip dependencies + - name: Install PyTorch + run: | + pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu + - name: Update requirements.txt run: | sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 74b4bcfbd..c104e92c2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,6 +36,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] + pytorch_version: ["2.3.1", "2.4.0"] timeout-minutes: 20 steps: @@ -48,6 +49,10 @@ jobs: python-version: ${{ matrix.python_version }} cache: 'pip' # caching pip dependencies + - name: Install PyTorch + run: | + pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu + - name: Install dependencies run: | pip3 install --upgrade pip diff --git a/README.md b/README.md index af604fad5..c84f1cb8c 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Features: - Supports fullfinetune, lora, qlora, relora, and gptq - Customize configurations using a simple yaml file or CLI overwrite - Load different dataset formats, use custom formats, or bring your own tokenized datasets -- Integrated with xformer, flash attention, rope scaling, and multipacking +- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking - Works with single GPU or multiple GPUs via FSDP or Deepspeed - Easily run with Docker locally or on the cloud - Log results and optionally checkpoints to wandb or mlflow diff --git a/_quarto.yml b/_quarto.yml index 6b2eed971..acb487258 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -37,6 +37,7 @@ website: - docs/mac.qmd - docs/multi-node.qmd - docs/unsloth.qmd + - docs/amd_hpc.qmd - section: "Dataset Formats" contents: docs/dataset-formats/* - section: "Reference" diff --git a/cicd/cicd.sh b/cicd/cicd.sh index eceda9b37..104a8f84a 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -2,5 +2,5 @@ set -e pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ -pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ -pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ /workspace/axolotl/tests/e2e/ +pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ +pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ diff --git a/docs/amd_hpc.qmd b/docs/amd_hpc.qmd new file mode 100644 index 000000000..d1c274e15 --- /dev/null +++ b/docs/amd_hpc.qmd @@ -0,0 +1,108 @@ +--- +title: Training with AMD GPUs on HPC Systems +description: A comprehensive guide for using Axolotl on distributed systems with AMD GPUs +--- + +This guide provides step-by-step instructions for installing and configuring Axolotl on a High-Performance Computing (HPC) environment equipped with AMD GPUs. + +## Setup + +### 1. Install Python + +We recommend using Miniforge, a minimal conda-based Python distribution: + +```bash +curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" +bash Miniforge3-$(uname)-$(uname -m).sh +``` + +### 2. Configure Python Environment +Add Python to your PATH and ensure it's available at login: + +```bash +echo 'export PATH=~/miniforge3/bin:$PATH' >> ~/.bashrc +echo 'if [ -f ~/.bashrc ]; then . ~/.bashrc; fi' >> ~/.bash_profile +``` + +### 3. Load AMD GPU Software + +Load the ROCm module: + +```bash +module load rocm/5.7.1 +``` + +Note: The specific module name and version may vary depending on your HPC system. Consult your system documentation for the correct module name. + +### 4. Install PyTorch + +Install PyTorch with ROCm support: + +```bash +pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7 --force-reinstall +``` + +### 5. Install Flash Attention + +Clone and install the Flash Attention repository: + +```bash +git clone --recursive https://github.com/ROCmSoftwarePlatform/flash-attention.git +export GPU_ARCHS="gfx90a" +cd flash-attention +export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])') +patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch +pip install . +``` + +### 6. Install Axolotl + +Clone and install Axolotl: + +```bash +git clone https://github.com/axolotl-ai-cloud/axolotl +cd axolotl +pip install packaging ninja +pip install -e . +``` + +### 7. Apply xformers Workaround + +xformers appears to be incompatible with ROCm. Apply the following workarounds: + - Edit $HOME/packages/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py modifying the code to always return `False` for SwiGLU availability from xformers. + - Edit $HOME/miniforge3/lib/python3.10/site-packages/xformers/ops/swiglu_op.py replacing the "SwiGLU" function with a pass statement. + +### 8. Prepare Job Submission Script + +Create a script for job submission using your HPC's particular software (e.g. Slurm, PBS). Include necessary environment setup and the command to run Axolotl training. If the compute node(s) do(es) not have internet access, it is recommended to include + +```bash +export TRANSFORMERS_OFFLINE=1 +export HF_DATASETS_OFFLINE=1 +``` + +### 9. Download Base Model + +Download a base model using the Hugging Face CLI: + +```bash +huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B +``` + +### 10. Create Axolotl Configuration + +Create an Axolotl configuration file (YAML format) tailored to your specific training requirements and dataset. Use FSDP for multi-node training. + +Note: Deepspeed did not work at the time of testing. However, if anyone managed to get it working, please let us know. + +### 11. Preprocess Data + +Run preprocessing on the login node: + +```bash +CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess /path/to/your/config.yaml +``` + +### 12. Train + +You are now ready to submit your previously prepared job script. 🚂 diff --git a/docs/dataset-formats/tokenized.qmd b/docs/dataset-formats/tokenized.qmd index b2ea003c0..61028cae7 100644 --- a/docs/dataset-formats/tokenized.qmd +++ b/docs/dataset-formats/tokenized.qmd @@ -7,7 +7,7 @@ order: 5 - Pass an empty `type:` in your axolotl config. - Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels` - To indicate that a token should be ignored during training, set its corresponding label to `-100`. -- Do not add BOS/EOS. Axolotl will add them for you based on the default tokenizer for the model you're using. +- You must add BOS and EOS, and make sure that you are training on EOS by not setting its label to -100. - For pretraining, do not truncate/pad documents to the context window length. - For instruction training, documents must be truncated/padded as desired. diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml index a64965d20..e84d221f8 100644 --- a/examples/llama-3/fft-8b-liger-fsdp.yaml +++ b/examples/llama-3/fft-8b-liger-fsdp.yaml @@ -31,7 +31,7 @@ wandb_log_model: gradient_accumulation_steps: 4 micro_batch_size: 2 num_epochs: 1 -optimizer: paged_adamw_8bit +optimizer: adamw_torch lr_scheduler: cosine learning_rate: 2e-5 diff --git a/requirements.txt b/requirements.txt index f5fb547a2..83116af60 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.12.0 -transformers==4.44.0 +transformers==4.44.2 tokenizers>=0.19.1 bitsandbytes==0.43.3 -accelerate==0.33.0 +accelerate==0.34.2 datasets==2.20.0 deepspeed==0.14.4 pydantic==2.6.3 @@ -34,7 +34,7 @@ tensorboard python-dotenv==1.0.1 autoawq>=0.2.5 triton>=2.3.0 -liger-kernel +liger-kernel==0.2.1 mamba-ssm==1.2.0.post1 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 656ded255..f4cd25783 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -506,9 +506,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer): batch_max_len = self.args.max_seq_length else: batch_size = 1 - batch_max_len = ( - self.args.per_device_train_batch_size * self.args.max_seq_length + train_batch_size = ( + self.state.train_batch_size or self.args.per_device_train_batch_size ) + batch_max_len = train_batch_size * self.args.max_seq_length return MultipackBatchSampler( RandomSampler(self.train_dataset), lengths=get_dataset_lengths(self.train_dataset), @@ -1379,6 +1380,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs[ "per_device_eval_batch_size" ] = self.cfg.eval_batch_size + if self.cfg.auto_find_batch_size is not None: + training_arguments_kwargs[ + "auto_find_batch_size" + ] = self.cfg.auto_find_batch_size training_arguments_kwargs[ "gradient_accumulation_steps" ] = self.cfg.gradient_accumulation_steps @@ -1461,9 +1466,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ) training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) - training_arguments_kwargs[ - "multipack_real_batches" - ] = not self.cfg.flash_attention + training_arguments_kwargs["multipack_real_batches"] = ( + not self.cfg.flash_attention or self.cfg.multipack_real_batches + ) training_arguments_kwargs["eval_sample_packing"] = bool( self.cfg.eval_sample_packing ) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 2a3e95163..2047f3815 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -20,10 +20,10 @@ It is designed to be performant, correct, and light-weight. """ import logging import sys +from functools import partial from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP -from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import LigerSwiGLUMLP @@ -43,6 +43,9 @@ class LigerPlugin(BasePlugin): def pre_model_load(self, cfg): if cfg.model_config_type == "llama": + from liger_kernel.transformers.model.llama import ( + lce_forward as llama_lce_forward, + ) from transformers.models.llama import modeling_llama if cfg.liger_rope: @@ -57,6 +60,9 @@ class LigerPlugin(BasePlugin): modeling_llama.LlamaForCausalLM.forward = llama_lce_forward elif cfg.model_config_type == "mistral": + from liger_kernel.transformers.model.mistral import ( + lce_forward as mistral_lce_forward, + ) from transformers.models.mistral import modeling_mistral if cfg.liger_rope: @@ -68,25 +74,26 @@ class LigerPlugin(BasePlugin): if cfg.liger_cross_entropy: modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: - logging.warning( - "Fused linear cross entropy is not supported for Mistral." - ) + modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward elif cfg.model_config_type == "gemma": + from liger_kernel.transformers.model.gemma import ( + lce_forward as gemma_lce_forward, + ) from transformers.models.gemma import modeling_gemma if cfg.liger_rope: modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb if cfg.liger_rms_norm: - modeling_gemma.GemmaRMSNorm = LigerRMSNorm + modeling_gemma.GemmaRMSNorm = partial( + LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" + ) if cfg.liger_swiglu: modeling_gemma.GemmaMLP = LigerGEGLUMLP if cfg.liger_cross_entropy: modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: - logging.warning( - "Fused linear cross entropy is not supported for Gemma." - ) + modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward elif cfg.model_config_type == "jamba": from transformers.models.jamba import modeling_jamba @@ -145,3 +152,38 @@ class LigerPlugin(BasePlugin): modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward + + elif cfg.model_config_type == "gemma2": + from transformers.models.gemma2 import modeling_gemma2 + + if cfg.liger_rope: + modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_gemma2.Gemma2RMSNorm = partial( + LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" + ) + if cfg.liger_swiglu: + modeling_gemma2.Gemma2MLP = LigerGEGLUMLP + if cfg.liger_cross_entropy: + modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + logging.warning( + "Fused linear cross entropy is not supported for Gemma 2." + ) + + elif cfg.model_config_type == "phi3": + from liger_kernel.transformers.model.phi3 import ( + lce_forward as phi3_lce_forward, + ) + from transformers.models.phi3 import modeling_phi3 + + if cfg.liger_rope: + modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_phi3.Phi3RMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_phi3.Phi3MLP = LigerSwiGLUMLP + if cfg.liger_cross_entropy: + modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index f147c645b..5892db1e7 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -382,6 +382,8 @@ class HyperparametersConfig(BaseModel): }, ) + auto_find_batch_size: Optional[bool] = None + train_on_inputs: Optional[bool] = False group_by_length: Optional[bool] = None @@ -619,6 +621,7 @@ class AxolotlInputConfig( eval_sample_packing: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None curriculum_sampling: Optional[bool] = None + multipack_real_batches: Optional[bool] = None # for PoSE context length extension use_pose: Optional[bool] = None diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 957ca5746..205c2894d 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -11,6 +11,8 @@ import numba import numpy as np from torch.utils.data import BatchSampler, Sampler +from axolotl.utils.distributed import reduce_and_broadcast + LOG = logging.getLogger("axolotl.utils.samplers.multipack") @@ -174,16 +176,46 @@ class MultipackBatchSampler(BatchSampler): def efficiency(self): return self.eff_total_used / self.eff_total_slots + def gather_efficiency(self): + def calc_sample_packing_eff_est(estimates: List[float]): + LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}") + return math.floor(0.997 * max(estimates)) + + sample_packing_actual_eff_all = reduce_and_broadcast( + lambda: self.efficiency(), # pylint: disable=unnecessary-lambda + calc_sample_packing_eff_est, + ) + sample_packing_eff_est = ( + math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0 + ) + return sample_packing_eff_est + + def gather_len_batches(self, num): + def calc_min_len(estimates: list[(int, float)]): + LOG.info(f"gather_len_batches: {repr(estimates)}") + return math.floor(0.998 * min(estimates)) + + min_len_batches = reduce_and_broadcast( + lambda: num, + calc_min_len, + ) + return min_len_batches + def __len__(self): - self.num_batches() - return self._len_est() + len_batches = self.num_batches() + return self.gather_len_batches(len_batches) def _len_est(self): + efficiency = ( + self.packing_efficiency_estimate + if self.packing_efficiency_estimate + else self.gather_efficiency() + ) world_size = int(os.getenv("WORLD_SIZE", "1")) lengths_sum = np.sum(self.lengths) lengths_sum_per_device = lengths_sum // world_size LOG.info( - f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + f"packing_efficiency_estimate: {efficiency} " f"total_num_tokens per device: {lengths_sum_per_device}" ) @@ -195,7 +227,7 @@ class MultipackBatchSampler(BatchSampler): * math.floor( 0.99 * lengths_sum_per_device - / self.packing_efficiency_estimate + / efficiency // (self.batch_max_len * self.batch_size) ) - 1 diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f4e1fc6cb..89ae4e697 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -357,7 +357,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): main_process_only=True, ) else: - if cfg.flash_attention: + if cfg.flash_attention and not cfg.multipack_real_batches: sampler_batch_size = 1 batch_max_len = cfg.micro_batch_size * cfg.sequence_len else: @@ -425,7 +425,8 @@ def setup_deepspeed_env(cfg, stage=None): os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage) if stage == 3: os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true" - HfTrainerDeepSpeedConfig(cfg.deepspeed) + # If we don't assign this, it doesn't actually get set in the accelerate weakref + _ = HfTrainerDeepSpeedConfig(cfg.deepspeed) def setup_fsdp_envs(cfg): diff --git a/tests/e2e/integrations/__init__.py b/tests/e2e/integrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/e2e/integrations/liger.py b/tests/e2e/integrations/liger.py new file mode 100644 index 000000000..4497cebe3 --- /dev/null +++ b/tests/e2e/integrations/liger.py @@ -0,0 +1,110 @@ +""" +Simple end-to-end test for Liger integration +""" + +import unittest +from pathlib import Path + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from ..utils import with_temp_dir + + +class LigerIntegrationTestCase(unittest.TestCase): + """ + e2e tests for liger integration with Axolotl + """ + + @with_temp_dir + def test_llama_wo_flce(self, temp_dir): + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "plugins": [ + "axolotl.integrations.liger.LigerPlugin", + ], + "liger_rope": True, + "liger_rms_norm": True, + "liger_swiglu": True, + "liger_cross_entropy": True, + "liger_fused_linear_cross_entropy": False, + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() + + @with_temp_dir + def test_llama_w_flce(self, temp_dir): + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "plugins": [ + "axolotl.integrations.liger.LigerPlugin", + ], + "liger_rope": True, + "liger_rms_norm": True, + "liger_swiglu": True, + "liger_cross_entropy": False, + "liger_fused_linear_cross_entropy": True, + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 344c57fb8..61bb8ed32 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -10,6 +10,7 @@ from pathlib import Path import pytest import yaml from accelerate.test_utils import execute_subprocess_async +from huggingface_hub import snapshot_download from axolotl.utils.dict import DictDefault @@ -19,6 +20,12 @@ LOG = logging.getLogger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" +@pytest.fixture(scope="session", autouse=True) +def download_model(): + # download the model + snapshot_download("TinyLlama/TinyLlama_v1.1") + + class TestMultiGPULlama(unittest.TestCase): """ Test case for Llama models using LoRA