gracefully handle missing ring-flash-attn
@@ -15,7 +15,6 @@ import torch.distributed as dist
 import torch.nn.functional as F
 from datasets import Dataset
 from peft.optimizers import create_loraplus_optimizer
-from ring_flash_attn import update_ring_flash_attn_params
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -39,6 +38,17 @@ from axolotl.utils.schedulers import (
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
 
+
+try:
+    from ring_flash_attn import update_ring_flash_attn_params
+except ImportError:
+    # pylint: disable=unused-argument
+    def update_ring_flash_attn_params(*args, **kwargs):
+        raise ImportError(
+            "ring_flash_attn is not installed. "
+            "Please install it with `pip install ring-flash-attn>=0.1.4`"
+        )
+
 
 LOG = logging.getLogger(__name__)
 
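The effect of the change: the hard top-level import is replaced by a try/except fallback, so merely importing the module no longer requires ring-flash-attn to be installed; the ImportError is deferred until update_ring_flash_attn_params is actually called. A minimal standalone sketch of the pattern, using the names from the diff (the demo call at the bottom is illustrative, not part of the change):

# Minimal sketch of the optional-dependency pattern introduced above.
# `ring_flash_attn` / `update_ring_flash_attn_params` are the real names
# from the diff; the demo call at the end is hypothetical.
import importlib.util

try:
    from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
    # Stub with the same name: module import succeeds either way, and the
    # ImportError is raised lazily, only when the function is used.
    def update_ring_flash_attn_params(*args, **kwargs):
        raise ImportError(
            "ring_flash_attn is not installed. "
            "Please install it with `pip install ring-flash-attn>=0.1.4`"
        )

# Demo: exercise the stub only when the package is genuinely absent.
if importlib.util.find_spec("ring_flash_attn") is None:
    try:
        update_ring_flash_attn_params()
    except ImportError as err:
        print(err)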