various batch of fixes (#1785)

* various batch of fixes

* more tweaks

* fix autoawq requirement for torch flexibility

* simplify conditionals

* multi-node fixes wip

* bump transformers and include 405b qlora+fsdp yaml
This commit is contained in:
Wing Lian
2024-07-28 07:25:54 -04:00
committed by GitHub
parent 22680913f3
commit 94ba93259f
11 changed files with 253 additions and 99 deletions

View File

@@ -0,0 +1,62 @@
base_model: meta-llama/Meta-Llama-3.1-405B
tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b
adapter: qlora
sequence_len: 1024
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|finetune_right_pad_id|>

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.11.1
transformers==4.43.1
transformers==4.43.3
tokenizers==0.19.1
bitsandbytes==0.43.1
accelerate==0.32.0
@@ -32,6 +32,7 @@ fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e59
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
mamba-ssm==1.2.0.post1

View File

@@ -2,6 +2,7 @@
CLI to run training on a model
"""
import logging
import warnings
from pathlib import Path
from typing import Union
@@ -76,8 +77,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
if parsed_cli_args.download:
model_name = parsed_cfg.base_model
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
with warnings.catch_warnings():
# there are a bunch of useless UserWarnings about
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
warnings.simplefilter("ignore")
with init_empty_weights(include_buffers=True):
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
LOG.info(
Fore.GREEN

View File

@@ -0,0 +1,14 @@
"""
Common architecture specific constants
"""
MOE_ARCH_BLOCK = {
"dbrx": "DbrxFFN",
"jamba": "JambaSparseMoeBlock",
"jetmoe": [
"JetMoeMoA",
"JetMoeMoE",
],
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
}

View File

@@ -8,6 +8,7 @@ import importlib
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from collections import defaultdict
@@ -28,7 +29,7 @@ from transformers import (
TrainerCallback,
TrainingArguments,
)
from transformers.trainer_utils import seed_worker
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import (
CPOConfig,
@@ -286,7 +287,77 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
)
class AxolotlTrainer(Trainer):
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: AxolotlTrainingArguments
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
class AxolotlTrainer(SchedulerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
@@ -404,68 +475,6 @@ class AxolotlTrainer(Trainer):
return self.optimizer
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
@@ -830,6 +839,14 @@ class AxolotlTrainer(Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, metrics=None):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, metrics=metrics)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -929,7 +946,7 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler
class AxolotlDPOTrainer(DPOTrainer):
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
@@ -990,7 +1007,7 @@ class AxolotlDPOTrainer(DPOTrainer):
return res
class AxolotlORPOTrainer(ORPOTrainer):
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
@@ -998,7 +1015,7 @@ class AxolotlORPOTrainer(ORPOTrainer):
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(KTOTrainer):
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
@@ -1006,7 +1023,7 @@ class AxolotlKTOTrainer(KTOTrainer):
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(CPOTrainer):
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
@@ -1750,6 +1767,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha

View File

@@ -62,7 +62,7 @@ def default(
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:]
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[rejected],
@@ -71,7 +71,7 @@ def default(
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])
result["rejected"] = result["rejected"][rejected_strip_index:]
result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
return result

View File

@@ -212,26 +212,23 @@ def train(
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
trainer.save_model(cfg.output_dir)
# the trainer saved a model.safetensors file in the output directory,
# but it is a proxy model and should be deleted
if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
# but it is most likely a proxy model and if so, should be deleted
maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors"))
maybe_sharded = os.path.exists(
os.path.join(cfg.output_dir, "model.safetensors.index.json")
)
if maybe_proxy and maybe_sharded:
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
LOG.info("This is a proxy model and should be deleted")
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
try:
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)

View File

@@ -42,7 +42,7 @@ from axolotl.prompters import (
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
@@ -54,7 +54,7 @@ LOG = logging.getLogger("axolotl")
def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_main_process()):
with zero_first(is_local_main_process()):
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
@@ -170,6 +170,7 @@ def load_tokenized_prepared_datasets(
# pylint: disable=duplicate-code
if dataset:
# This is for the case where we already loaded a pretokenized dataset from the hub
...
elif (
cfg.dataset_prepared_path
@@ -198,6 +199,8 @@ def load_tokenized_prepared_datasets(
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
else:
@@ -208,6 +211,8 @@ def load_tokenized_prepared_datasets(
ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset(
config_dataset.path,
name=config_dataset.name,

View File

@@ -44,6 +44,10 @@ def is_main_process():
return dist.get_rank() == 0
def is_local_main_process():
return PartialState().is_main_process
def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))

View File

@@ -29,6 +29,7 @@ from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
PreTrainedModel,
@@ -36,6 +37,7 @@ from transformers import ( # noqa: F401
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -510,7 +512,25 @@ def load_model(
model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config
)
if cfg.adapter == "qlora" and cfg.load_in_4bit:
if (
cfg.adapter in ["qlora", "lora"]
and hasattr(model_config, "quantization_config")
and model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
):
if model_config.quantization_config["quant_method"] == "gptq":
model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config
)
elif model_config.quantization_config["quant_method"] == "awq":
model_kwargs["quantization_config"] = AwqConfig(
**model_config.quantization_config
)
elif model_config.quantization_config["quant_method"] == "bitsandbytes":
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**model_config.quantization_config
)
elif cfg.adapter == "qlora" and cfg.load_in_4bit:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
@@ -785,12 +805,14 @@ def load_model(
set_z3_leaf_modules,
)
if cfg.model_config_type == "mixtral":
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
set_z3_leaf_modules(model, [moe_block])
elif cfg.model_config_type == "dbrx":
moe_block = get_module_class_from_name(model, "DbrxFFN")
set_z3_leaf_modules(model, [moe_block])
if cfg.model_config_type in MOE_ARCH_BLOCK:
set_z3_leaf_modules(
model,
[
get_module_class_from_name(model, module_name)
for module_name in MOE_ARCH_BLOCK[cfg.model_config_type]
],
)
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
@@ -804,6 +826,9 @@ def load_model(
# make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
if is_deepspeed_zero3_enabled():
skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]:
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable(
@@ -838,6 +863,9 @@ def load_model(
else:
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
if (
cfg.ddp
and not load_in_8bit

View File

@@ -1,4 +1,5 @@
"""Module containing the Trainer class and related functions"""
import json
import math
import os
import random
@@ -389,6 +390,19 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
return total_num_steps
def setup_deepspeed_env(cfg, stage=None):
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if cfg.bf16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_activation_checkpointing:
@@ -415,8 +429,14 @@ def prepare_optim_env(cfg):
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
stage = None
# check if the cfg.deepspeed is a file
if os.path.isfile(cfg.deepspeed):
# parse with json
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
deepspeed_config = json.load(fin)
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage)
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"