diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 109017145..e2b1ccc2b 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -15,7 +15,6 @@ import torch.distributed as dist import torch.nn.functional as F from datasets import Dataset from peft.optimizers import create_loraplus_optimizer -from ring_flash_attn import update_ring_flash_attn_params from torch import nn from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler @@ -39,6 +38,17 @@ from axolotl.utils.schedulers import ( if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp +try: + from ring_flash_attn import update_ring_flash_attn_params +except ImportError: + # pylint: disable=unused-argument + def update_ring_flash_attn_params(*args, **kwargs): + raise ImportError( + "ring_flash_attn is not installed. " + "Please install it with `pip install ring-flash-attn>=0.1.4`" + ) + + LOG = logging.getLogger(__name__)