more logging, wandb fixes

This commit is contained in:
Wing Lian
2023-04-15 13:37:17 -04:00
parent 2df63ef815
commit 05fffb53b4
5 changed files with 81 additions and 14 deletions

View File

@@ -23,7 +23,7 @@ lora_target_modules:
lora_fan_in_fan_out: false
wandb_project: pythia-1.4b-lora
wandb_watch:
wandb_run_name:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-alpaca
batch_size: 32

View File

@@ -25,7 +25,7 @@ lora_target_modules:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: llama-65b-lora
wandb_watch:
wandb_run_name:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-llama-alpaca
batch_size: 128

View File

@@ -25,7 +25,7 @@ lora_target_modules:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: pythia-1.4b-lora
wandb_watch:
wandb_run_name:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-alpaca
batch_size: 48

56
ds_config.json Normal file
View File

@@ -0,0 +1,56 @@
{
"bf16": {
"enabled": "auto",
},
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 5,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -1,3 +1,4 @@
import logging
import math
import os
import random
@@ -37,6 +38,9 @@ from axolotl.prompt_tokenizers import (
)
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
logger = logging.getLogger(__name__)
DEFAULT_DATASET_PREPARED_PATH = "data/last_run"
def setup_wandb_env_vars(cfg):
if len(cfg.wandb_project) > 0:
@@ -46,6 +50,8 @@ def setup_wandb_env_vars(cfg):
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
@@ -164,8 +170,8 @@ def check_dataset_labels(dataset, tokenizer):
)
colored_tokens.append(colored_token)
print(" ".join(colored_tokens))
print("\n\n\n")
logger.info(" ".join(colored_tokens))
logger.info("\n\n\n")
def do_inference(cfg, model, tokenizer):
@@ -247,7 +253,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
ddp_find_unused_parameters=False if cfg.ddp else None,
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
run_name=cfg.wandb_run_name if cfg.use_wandb else None,
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
**training_arguments_kwargs,
)
@@ -341,9 +347,9 @@ def train(
return
if cfg.dataset_prepared_path and any(Path(cfg.dataset_prepared_path).glob("*")):
print("Loading prepared dataset from disk...")
dataset = load_from_disk(cfg.datasets)
print("Prepared dataset loaded from disk...")
logger.info("Loading prepared dataset from disk...")
dataset = load_from_disk(cfg.dataset_prepared_path)
logger.info("Prepared dataset loaded from disk...")
else:
datasets = []
for d in cfg.datasets:
@@ -376,11 +382,12 @@ def train(
[_ for _ in constant_len_dataset]
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
print("Saving prepared dataset to disk...")
if cfg.dataset_prepared_path:
dataset.save_to_disk(cfg.dataset_prepared_path)
else:
dataset.save_to_disk("data/last_run")
if cfg.local_rank == 0:
logger.info("Saving prepared dataset to disk...")
if cfg.dataset_prepared_path:
dataset.save_to_disk(cfg.dataset_prepared_path)
else:
dataset.save_to_disk(DEFAULT_DATASET_PREPARED_PATH)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
@@ -396,9 +403,11 @@ def train(
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
logger.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
logger.info(f"Pre-saving adapter config to {cfg.output_dir}")
lora_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
@@ -407,9 +416,11 @@ def train(
lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
)
logger.info("Starting trainer...")
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
logger.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
model.save_pretrained(cfg.output_dir)