Print versions (#1496)
* print out dependency versions for easier debugging * improve readability
This commit is contained in:
@@ -24,6 +24,7 @@ from huggingface_hub import HfApi
|
|||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
from transformers.utils.import_utils import _is_package_available
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
@@ -62,6 +63,20 @@ def print_axolotl_text_art(suffix=None):
|
|||||||
if is_main_process():
|
if is_main_process():
|
||||||
print(ascii_art)
|
print(ascii_art)
|
||||||
|
|
||||||
|
print_dep_versions()
|
||||||
|
|
||||||
|
|
||||||
|
def print_dep_versions():
|
||||||
|
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
||||||
|
max_len = max(len(pkg) for pkg in packages)
|
||||||
|
if is_main_process():
|
||||||
|
print("*" * 40)
|
||||||
|
print("**** Axolotl Dependency Versions *****")
|
||||||
|
for pkg in packages:
|
||||||
|
version = _is_package_available(pkg, return_version=True)
|
||||||
|
print(f"{pkg: >{max_len}}: {version[1]: <15}")
|
||||||
|
print("*" * 40)
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]):
|
def check_remote_config(config: Union[str, Path]):
|
||||||
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
|
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
|
||||||
|
|||||||
Reference in New Issue
Block a user