remove un-needed code, add validation

This commit is contained in:
Wing Lian
2023-05-24 22:47:33 -04:00
parent 7e81ca720b
commit 1f5d83ea72
2 changed files with 3 additions and 15 deletions

View File

@@ -14,6 +14,7 @@ from attrdict import AttrDefault
# add src to the pythonpath so we don't need to pip install this
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.validation import validate_config
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
@@ -158,6 +159,8 @@ def train(
cfg.fp16 = True
cfg.bf16 = False
validate_config(cfg)
# Load the model and tokenizer
logging.info("loading model, tokenizer, and peft_config...")
model, tokenizer, peft_config = load_model(

View File

@@ -204,21 +204,6 @@ def load_model(
**model_kwargs,
)
"""### Post-processing on the model
Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
"""
# if cfg.adapter == "qlora":
# for param in model.parameters():
# param.requires_grad = False # freeze the model - train adapters later
# if param.ndim == 1:
# # cast the small parameters (e.g. layernorm) to fp32 for stability
# param.data = param.data.to(torch.float32)
# class CastOutputToFloat(nn.Linear):
# def forward(self, x):
# return super().forward(x).to(torch.float32)
#
# model.lm_head = CastOutputToFloat(model.lm_head.in_features, model.lm_head.out_features, model.lm_head.bias)
if not tokenizer:
try:
if is_llama_derived_model and "LlamaTokenizer" in globals():