hotfix for capabilities loading (#1331)

This commit is contained in:
Wing Lian
2024-02-26 14:24:28 -05:00
committed by GitHub
parent d75653407c
commit 7de912e097
2 changed files with 8 additions and 9 deletions

View File

@@ -30,7 +30,6 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta from axolotl.train import TrainDatasetMeta
from axolotl.utils.config import ( from axolotl.utils.config import (
GPUCapabilities,
normalize_cfg_datasets, normalize_cfg_datasets,
normalize_config, normalize_config,
validate_config, validate_config,
@@ -350,14 +349,15 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
except: # pylint: disable=bare-except # noqa: E722 except: # pylint: disable=bare-except # noqa: E722
gpu_version = None gpu_version = None
capabilities = GPUCapabilities( cfg = validate_config(
bf16=is_torch_bf16_gpu_available(), cfg,
n_gpu=os.environ.get("WORLD_SIZE", 1), capabilities={
compute_capability=gpu_version, "bf16": is_torch_bf16_gpu_available(),
"n_gpu": os.environ.get("WORLD_SIZE", 1),
"compute_capability": gpu_version,
},
) )
cfg = validate_config(cfg, capabilities=capabilities)
prepare_optim_env(cfg) prepare_optim_env(cfg)
normalize_config(cfg) normalize_config(cfg)

View File

@@ -13,7 +13,6 @@ from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities, AxolotlConfigWCapabilities,
AxolotlInputConfig, AxolotlInputConfig,
) )
from axolotl.utils.config.models.internals import GPUCapabilities
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config from axolotl.utils.models import load_model_config
@@ -197,7 +196,7 @@ def normalize_cfg_datasets(cfg):
cfg.datasets[idx].conversation = "chatml" cfg.datasets[idx].conversation = "chatml"
def validate_config(cfg: DictDefault, capabilities: Optional[GPUCapabilities] = None): def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
if capabilities: if capabilities:
return DictDefault( return DictDefault(
dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities)) dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))