move is_llama_derived_model into normalize_config (#524)
This commit is contained in:
committed by
GitHub
parent
09f154397e
commit
44454ae4c4
@@ -24,7 +24,7 @@ from axolotl.utils.config import normalize_config, validate_config
|
|||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.models import load_model_config, load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||||
|
|
||||||
@@ -216,15 +216,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
else:
|
else:
|
||||||
cfg[k] = kwargs[k]
|
cfg[k] = kwargs[k]
|
||||||
|
|
||||||
model_config = load_model_config(cfg)
|
|
||||||
|
|
||||||
# figure out if the model is llama
|
|
||||||
cfg.is_llama_derived_model = (
|
|
||||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
|
||||||
or cfg.is_llama_derived_model
|
|
||||||
or "llama" in cfg.base_model
|
|
||||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
|
||||||
)
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
|
from axolotl.utils.models import load_model_config
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -69,6 +70,16 @@ def normalize_config(cfg):
|
|||||||
else:
|
else:
|
||||||
cfg.torch_dtype = torch.float32
|
cfg.torch_dtype = torch.float32
|
||||||
|
|
||||||
|
model_config = load_model_config(cfg)
|
||||||
|
|
||||||
|
# figure out if the model is llama
|
||||||
|
cfg.is_llama_derived_model = (
|
||||||
|
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
||||||
|
or cfg.is_llama_derived_model
|
||||||
|
or "llama" in cfg.base_model
|
||||||
|
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||||
|
)
|
||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user