hotfix for capabilities loading (#1331)
This commit is contained in:
@@ -30,7 +30,6 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
|||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
GPUCapabilities,
|
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
normalize_config,
|
normalize_config,
|
||||||
validate_config,
|
validate_config,
|
||||||
@@ -350,14 +349,15 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
except: # pylint: disable=bare-except # noqa: E722
|
except: # pylint: disable=bare-except # noqa: E722
|
||||||
gpu_version = None
|
gpu_version = None
|
||||||
|
|
||||||
capabilities = GPUCapabilities(
|
cfg = validate_config(
|
||||||
bf16=is_torch_bf16_gpu_available(),
|
cfg,
|
||||||
n_gpu=os.environ.get("WORLD_SIZE", 1),
|
capabilities={
|
||||||
compute_capability=gpu_version,
|
"bf16": is_torch_bf16_gpu_available(),
|
||||||
|
"n_gpu": os.environ.get("WORLD_SIZE", 1),
|
||||||
|
"compute_capability": gpu_version,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg = validate_config(cfg, capabilities=capabilities)
|
|
||||||
|
|
||||||
prepare_optim_env(cfg)
|
prepare_optim_env(cfg)
|
||||||
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from axolotl.utils.config.models.input.v0_4_1 import (
|
|||||||
AxolotlConfigWCapabilities,
|
AxolotlConfigWCapabilities,
|
||||||
AxolotlInputConfig,
|
AxolotlInputConfig,
|
||||||
)
|
)
|
||||||
from axolotl.utils.config.models.internals import GPUCapabilities
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model_config
|
from axolotl.utils.models import load_model_config
|
||||||
|
|
||||||
@@ -197,7 +196,7 @@ def normalize_cfg_datasets(cfg):
|
|||||||
cfg.datasets[idx].conversation = "chatml"
|
cfg.datasets[idx].conversation = "chatml"
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg: DictDefault, capabilities: Optional[GPUCapabilities] = None):
|
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||||
if capabilities:
|
if capabilities:
|
||||||
return DictDefault(
|
return DictDefault(
|
||||||
dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
|
dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
|
||||||
|
|||||||
Reference in New Issue
Block a user