Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
effb281b24 wip for multipack pretraining 2023-11-25 17:12:20 -05:00
47 changed files with 107 additions and 230 deletions

View File

@@ -612,12 +612,6 @@ eval_sample_packing:
sample_packing_eff_est:
total_num_tokens:
# Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory
device_map:
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
max_memory:
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora
# If you already have a lora model trained that you want to load, put that here.
@@ -665,8 +659,7 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
wandb_project: # Your wandb project name
wandb_entity: # A wandb Team name if using a Team
wandb_watch:
wandb_name: # Set the name of your wandb run
wandb_run_id: # Set the ID of your wandb run
wandb_run_id: # Set the name of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# Where to save the full-finetuned model to
@@ -701,9 +694,6 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (require safetensors package)
save_safetensors:
@@ -962,7 +952,7 @@ wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
```

View File

@@ -24,6 +24,16 @@
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",

View File

@@ -28,6 +28,16 @@
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",

View File

@@ -32,6 +32,16 @@
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",

View File

@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: btlm-out

View File

@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
batch_size: 4

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./falcon-7b
batch_size: 2

View File

@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out

View File

@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./falcon-7b
batch_size: 2

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
gradient_accumulation_steps: 2

View File

@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./jeopardy-bot-7b
gradient_accumulation_steps: 1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1

View File

@@ -32,7 +32,7 @@ lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./model-out
gradient_accumulation_steps: 1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -35,7 +35,7 @@ relora_cpu_offload: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -21,7 +21,7 @@ pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -38,7 +38,7 @@ lora_target_modules:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -62,9 +62,6 @@ logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
eval_steps: 0.05
eval_table_size:

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./mpt-alpaca-7b
gradient_accumulation_steps: 1

View File

@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./openllama-out
gradient_accumulation_steps: 1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-out
gradient_accumulation_steps: 1

View File

@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
gradient_accumulation_steps: 1

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1

View File

@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./pythia-12b
gradient_accumulation_steps: 1

View File

@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -53,7 +53,7 @@ resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -53,7 +53,7 @@ resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05

View File

@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./redpajama-alpaca-3b
batch_size: 4

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project: lora-replit
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-replit
batch_size: 8

View File

@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out

View File

@@ -29,7 +29,6 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -72,7 +71,7 @@ def do_merge_lora(
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=cfg.torch_dtype)
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
@@ -297,8 +296,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
validate_config(cfg)
prepare_optim_env(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)

View File

@@ -25,7 +25,6 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
@@ -431,9 +430,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -647,7 +643,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_name if self.cfg.use_wandb else None
self.cfg.wandb_run_id if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"

View File

@@ -124,36 +124,6 @@ class GPUStatsCallback(
return control
class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3
def on_step_end(
self,
_args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
if self.violations >= self.patience:
LOG.warning(
"Loss is too high, stopping training (loss_watchdog_threshold)"
)
control.should_training_stop = True
else:
self.violations = 0
return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [

View File

@@ -27,7 +27,7 @@ def choose_device(cfg):
cfg.device = get_device()
if cfg.world_size == 1:
cfg.device_map = cfg.device_map or "auto"
cfg.device_map = "auto"
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()}
@@ -397,13 +397,6 @@ def validate_config(cfg):
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
)
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -698,6 +698,24 @@ def get_dataset_wrapper(
return dataset_wrapper, dataset_prompter
def encode_packed_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
):
# tokenize all the examples
# rows get split with stride (overlap)
res = tokenizer(
examples,
truncation=True,
max_length=max_tokens,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
# convert to a dataset.from_list
# use a dataloader and multipack batch sampler to pack the data
pass
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
@@ -813,6 +831,7 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
dataset = dataset.map(
encode,
batched=True,
batch_size=10_000,
input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column

View File

@@ -28,27 +28,6 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl")
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
quant_config_exists = hasattr(model_config, "quantization_config")
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config
and model_config.quantization_config["quant_method"] == "gptq"
)
if cfg.gptq and not quant_config_method_is_gptq:
raise ValueError(
"model_config.quantization_config is not set or quant_method is not set to gptq. "
"Please make sure to point to a GPTQ model."
)
if not cfg.gptq and quant_config_exists:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
)
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True
@@ -59,8 +38,6 @@ def load_model_config(cfg):
for key, val in cfg.model_config.items():
setattr(model_config, key, val)
check_model_config(cfg, model_config)
return model_config
@@ -239,7 +216,6 @@ def load_model(
model_kwargs = {}
model_kwargs["device_map"] = cfg.device_map
model_kwargs["max_memory"] = cfg.max_memory
model_kwargs["torch_dtype"] = cfg.torch_dtype
if cfg.model_revision:
@@ -436,22 +412,15 @@ def load_model(
module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
skip_prepare_model_for_kbit_training = True
if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
if not skip_prepare_model_for_kbit_training:
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to

View File

@@ -267,14 +267,12 @@ def setup_fsdp_envs(cfg):
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset

View File

@@ -2,20 +2,20 @@
import os
from axolotl.utils.dict import DictDefault
def setup_wandb_env_vars(cfg: DictDefault):
for key in cfg.keys():
if key.startswith("wandb_"):
value = cfg.get(key, "")
if value and isinstance(value, str) and len(value) > 0:
os.environ[key.upper()] = value
# Enable wandb if project name is present
if cfg.wandb_project and len(cfg.wandb_project) > 0:
def setup_wandb_env_vars(cfg):
if cfg.wandb_mode and cfg.wandb_mode == "offline":
os.environ["WANDB_MODE"] = cfg.wandb_mode
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True
os.environ.pop("WANDB_DISABLED", None) # Remove if present
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
else:
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -1,7 +1,6 @@
"""Module for testing the validation module"""
import logging
import os
import unittest
from typing import Optional
@@ -9,7 +8,6 @@ import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars
class ValidationTest(unittest.TestCase):
@@ -681,83 +679,3 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
class ValidationWandbTest(ValidationTest):
"""
Validation test for wandb
"""
def test_wandb_set_run_id_to_name(self):
cfg = DictDefault(
{
"wandb_run_id": "foo",
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
in record.message
for record in self._caplog.records
)
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
cfg = DictDefault(
{
"wandb_name": "foo",
}
)
validate_config(cfg)
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
def test_wandb_sets_env(self):
cfg = DictDefault(
{
"wandb_project": "foo",
"wandb_name": "bar",
"wandb_run_id": "bat",
"wandb_entity": "baz",
"wandb_mode": "online",
"wandb_watch": "false",
"wandb_log_model": "checkpoint",
}
)
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_PROJECT", "") == "foo"
assert os.environ.get("WANDB_NAME", "") == "bar"
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
assert os.environ.get("WANDB_ENTITY", "") == "baz"
assert os.environ.get("WANDB_MODE", "") == "online"
assert os.environ.get("WANDB_WATCH", "") == "false"
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
assert os.environ.get("WANDB_DISABLED", "") != "true"
def test_wandb_set_disabled(self):
cfg = DictDefault({})
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_DISABLED", "") == "true"
cfg = DictDefault(
{
"wandb_project": "foo",
}
)
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_DISABLED", "") != "true"