add e2e tests for Unsloth qlora and test the builds (#2093)
* see if unsloth installs cleanly in ci * check unsloth install on regular tests, not sdist * fix ampere check exception for ci * use cached_property instead * add an e2e test for unsloth qlora * reduce seq len and mbsz to prevent oom in ci * add checks for fp16 and sdp_attention * pin unsloth to a specific release * add unsloth to docker image too * fix flash attn xentropy patch * fix loss, add check for loss when using fa_xentropy * fix special tokens for test * typo * test fa xentropy with and without gradient accum * pr feedback changes
This commit is contained in:
@@ -8,7 +8,10 @@ from packaging.version import Version as V
|
||||
|
||||
v = V(torch.__version__)
|
||||
cuda = str(torch.version.cuda)
|
||||
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
||||
try:
|
||||
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
||||
except RuntimeError:
|
||||
is_ampere = False
|
||||
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
|
||||
raise RuntimeError(f"CUDA = {cuda} not supported!")
|
||||
if v <= V("2.1.0"):
|
||||
@@ -29,5 +32,5 @@ else:
|
||||
raise RuntimeError(f"Torch = {v} too new!")
|
||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||
print(
|
||||
f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
|
||||
f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user