gracefully handle no ring-flash-attn
This commit is contained in:
@@ -15,7 +15,6 @@ import torch.distributed as dist
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from ring_flash_attn import update_ring_flash_attn_params
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
@@ -39,6 +38,17 @@ from axolotl.utils.schedulers import (
|
|||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
import smdistributed.modelparallel.torch as smp
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
except ImportError:
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
def update_ring_flash_attn_params(*args, **kwargs):
|
||||||
|
raise ImportError(
|
||||||
|
"ring_flash_attn is not installed. "
|
||||||
|
"Please install it with `pip install ring-flash-attn>=0.1.4`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user