From cb3a9e99a33679415adcf8ac03adba04a7b189b0 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 14 Mar 2025 01:07:25 +0000
Subject: [PATCH] gracefully handle no ring-flash-attn

---
 src/axolotl/core/trainers/base.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 109017145..e2b1ccc2b 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -15,7 +15,6 @@ import torch.distributed as dist
 import torch.nn.functional as F
 from datasets import Dataset
 from peft.optimizers import create_loraplus_optimizer
-from ring_flash_attn import update_ring_flash_attn_params
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -39,6 +38,17 @@ from axolotl.utils.schedulers import (
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
 
+try:
+    from ring_flash_attn import update_ring_flash_attn_params
+except ImportError:
+    # pylint: disable=unused-argument
+    def update_ring_flash_attn_params(*args, **kwargs):
+        raise ImportError(
+            "ring_flash_attn is not installed. "
+            "Please install it with `pip install ring-flash-attn>=0.1.4`"
+        )
+
+
 LOG = logging.getLogger(__name__)
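
For reference, a minimal standalone sketch of the deferred-import pattern this patch
introduces. Nothing here is new API; it only demonstrates that, with the fallback in
place, importing the enclosing module succeeds even when ring-flash-attn is absent,
and the ImportError surfaces only if the stub is actually called. The snippet is
runnable on its own (assuming ring-flash-attn is not installed, so the fallback path
is taken):

    # Sketch of the patch's fallback in isolation.
    try:
        from ring_flash_attn import update_ring_flash_attn_params
    except ImportError:
        # The stub keeps module import working; the error is deferred
        # from import time to call time.
        def update_ring_flash_attn_params(*args, **kwargs):
            raise ImportError(
                "ring_flash_attn is not installed. "
                "Please install it with `pip install ring-flash-attn>=0.1.4`"
            )

    # Importing succeeded; calling the stub raises an actionable error
    # instead of the module failing to import at all.
    try:
        update_ring_flash_attn_params()
    except ImportError as err:
        print(err)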