gracefully handle no ring-flash-attn

This commit is contained in:
Dan Saunders
2025-03-14 01:07:25 +00:00
parent 3ae47ec7de
commit cb3a9e99a3

View File

@@ -15,7 +15,6 @@ import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from ring_flash_attn import update_ring_flash_attn_params
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -39,6 +38,17 @@ from axolotl.utils.schedulers import (
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# pylint: disable=unused-argument
def update_ring_flash_attn_params(*args, **kwargs):
raise ImportError(
"ring_flash_attn is not installed. "
"Please install it with `pip install ring-flash-attn>=0.1.4`"
)
LOG = logging.getLogger(__name__)