Attempt to run multigpu in PR CI for now to ensure it works (#1815) [skip ci]
* Attempt to run multigpu in PR CI for now to ensure it works * fix yaml file * forgot to include multigpu tests * fix call to cicd.multigpu * dump dictdefault to dict for yaml conversion * use to_dict instead of casting * 16bit-lora w flash attention, 8bit lora seems problematic * add llama fsdp test * more tests * Add test for qlora + fsdp with prequant * limit accelerate to 2 processes and disable broken qlora+fsdp+bnb test * move multigpu tests to biweekly
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
modal application to run axolotl gpu tests in Modal
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
@@ -21,9 +23,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.0.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.10-cu118-2.0.1"),
|
||||
"CUDA": os.environ.get("CUDA", "118"),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
|
||||
"CUDA": os.environ.get("CUDA", "121"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user