Compare commits

2 Commits: hymba_mult...debug-hf-h

| Author | SHA1 | Date |
|---|---|---|
| | 59047ee6c4 | |
| | c1b920f291 | |
.pre-commit-config.yaml

```diff
@@ -23,7 +23,7 @@ repos:
     hooks:
       - id: flake8
   - repo: https://github.com/PyCQA/pylint
-    rev: v2.17.4
+    rev: v3.3.0
     hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
```
.pylintrc

```diff
@@ -1,5 +1,5 @@
 [MASTER]
-init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
+init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
 
 [TYPECHECK]
 
```
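The init-hook rewrite tracks the pylint bump in .pre-commit-config.yaml above: pylint 3.x removed `pylint.config.find_pylintrc`, and `find_default_config_files()` is its replacement, yielding `pathlib.Path` objects rather than a path string. A minimal sketch of what the new hook evaluates, unrolled for readability (variable names are illustrative):

```python
# Assumes pylint >= 3.x, where find_pylintrc no longer exists.
import sys

from pylint.config import find_default_config_files

# find_default_config_files() is a generator of pathlib.Path objects, one
# per config file pylint discovers; the first hit is this project's
# pylintrc. Its parent directory goes on sys.path so first-party imports
# resolve while linting.
first_config = next(find_default_config_files())
sys.path.append(first_config.parent.as_posix())
```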
```diff
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
 disable=missing-function-docstring, line-too-long, import-error,
         too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
         too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
+        too-many-positional-arguments, possibly-used-before-assignment
```
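Both newly disabled messages are checks added in recent pylint releases (`too-many-positional-arguments`, R0917, shipped with the 3.3 line), so the rev bump would otherwise start flagging existing code. A hypothetical function that would trip R0917 under pylint's default threshold of five positional parameters:

```python
# Hypothetical example, not from the repo: six positional parameters exceed
# pylint 3.3's default max-positional-arguments of 5, triggering R0917
# (too-many-positional-arguments) unless the check is disabled.
def train_step(model, batch, optimizer, scheduler, step, total_steps):
    """Run one optimization step (body elided)."""
```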
setup.py

```diff
@@ -1,4 +1,5 @@
 """setup.py for axolotl"""
 
 import ast
 import os
 import platform
+import re
```
```diff
@@ -29,15 +30,29 @@ def parse_requirements():
         elif not is_extras and line and line[0] != "#":
             # Handle standard packages
             _install_requires.append(line)
 
     try:
         xformers_version = [req for req in _install_requires if "xformers" in req][0]
         torchao_version = [req for req in _install_requires if "torchao" in req][0]
         autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
 
         if "Darwin" in platform.system():
-            # don't install xformers on MacOS
-            _install_requires.pop(_install_requires.index(xformers_version))
+            # skip packages not compatible with OSX
+            skip_packages = [
+                "bitsandbytes",
+                "triton",
+                "mamba-ssm",
+                "flash-attn",
+                "xformers",
+                "autoawq",
+                "liger-kernel",
+            ]
+            _install_requires = [
+                req
+                for req in _install_requires
+                if re.split(r"[>=<]", req)[0].strip() not in skip_packages
+            ]
+            print(
+                _install_requires, [req in skip_packages for req in _install_requires]
+            )
         else:
             # detect the version of torch already installed
             # and set it so dependencies don't clobber the torch version
```
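The new filter keys on the bare package name by splitting each requirement at the first version operator. A minimal check of that extraction with hypothetical requirement strings (note the added `print` looks like debug output, and its `req in skip_packages` test compares full requirement strings, so it reports `False` for every surviving entry):

```python
# Sketch of the base-name filter above; these requirement strings are
# hypothetical examples, not axolotl's real pins.
import re

skip_packages = ["bitsandbytes", "xformers", "flash-attn"]
reqs = ["xformers==0.0.27", "packaging>=23.2", "flash-attn", "torch>=2.4.0"]

kept = [
    req
    for req in reqs
    # re.split(r"[>=<]", "torch>=2.4.0")[0] -> "torch": everything before
    # the first version operator is the bare package name.
    if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
print(kept)  # ['packaging>=23.2', 'torch>=2.4.0']
```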
```diff
@@ -43,7 +43,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
                 getattr, self.layers_attribute.split("."), self.trainer.model
             )
             LOG.info(
-                f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
+                f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers * 100 / len(layers)}%) every {self.step_interval} steps"
             )
 
         def freeze_all_layers(self):
```
```diff
@@ -270,7 +270,7 @@ def load_sharded_model_quant(
     model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
 
     if cfg.local_rank == 0 and verbose:
-        print(f"Loaded model weights in {time.time()-start:.3f} seconds")
+        print(f"Loaded model weights in {time.time() - start:.3f} seconds")
     # cleanup any extra memory usage from parallel loading
     torch.cuda.empty_cache()
 
```
```diff
@@ -37,7 +37,8 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
 
 @retry_on_request_exceptions(max_retries=3, delay=5)
 def snapshot_download_w_retry(*args, **kwargs):
-    return snapshot_download(*args, **kwargs)
+    url = snapshot_download(*args, **kwargs)
+    raise f"{args[0]}: {url}"
 
 
 @pytest.fixture(scope="session", autouse=True)
```
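This last hunk reads like temporary debug instrumentation (the target branch is `debug-hf-h`): in Python 3, `raise` with a plain string fails with `TypeError: exceptions must derive from BaseException`, so every call now errors out after the download completes. If surfacing the resolved path is the intent, a well-formed variant (a sketch, not the committed code) would wrap it in a real exception:

```python
# Hedged sketch only: raising a proper exception type keeps the debug
# intent (report the resolved path, then abort) without the TypeError
# that raising a bare string produces.
from huggingface_hub import snapshot_download


@retry_on_request_exceptions(max_retries=3, delay=5)  # decorator from this module
def snapshot_download_w_retry(*args, **kwargs):
    path = snapshot_download(*args, **kwargs)
    # args[0] is the repo id forwarded to snapshot_download
    raise RuntimeError(f"{args[0]}: {path}")
```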