Fix: modelloader handling of model_kwargs load_in*bit (#1999)
* fix: load_in_*bit not properly read * fix: load_*bit check * fix: typo * refactor: load * bit handling * feat: add test dpo lora multi-gpu * fix: turn off sample packing for dpo * fix: missing warmup_steps * fix: test to load in 8bit for lora * skip 8bit lora on h100, add 4bit lora on h100 to multi gpu tests * chore: reduce max_steps --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -9,6 +9,8 @@ from functools import wraps
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def with_temp_dir(test_func):
|
||||
@wraps(test_func)
|
||||
@@ -35,13 +37,18 @@ def most_recent_subdir(path):
|
||||
return subdir
|
||||
|
||||
|
||||
def require_torch_2_1_1(test_case):
|
||||
def require_torch_2_3_1(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.1.1
|
||||
Decorator marking a test that requires torch >= 2.3.1
|
||||
"""
|
||||
|
||||
def is_min_2_1_1():
|
||||
def is_min_2_3_1():
|
||||
torch_version = version("torch")
|
||||
return torch_version >= "2.1.1"
|
||||
return torch_version >= "2.3.1"
|
||||
|
||||
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
|
||||
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
|
||||
|
||||
|
||||
def is_hopper():
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
return compute_capability == (9, 0)
|
||||
|
||||
Reference in New Issue
Block a user