qlora-fsdp ram efficient loading with hf trainer (#1791)

* fix 405b with lower cpu ram requirements

* make sure to use doouble quant and only skip output embeddings

* set model attributes

* more fixes for sharded fsdp loading

* update the base model in example to use pre-quantized nf4-bf16 weights

* upstream fixes  for qlora+fsdp
This commit is contained in:
Wing Lian
2024-07-30 19:21:38 -04:00
committed by GitHub
parent dbf8fb549e
commit 3ebf22464b
10 changed files with 52 additions and 14 deletions

View File

@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub" ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1" ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub" ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1" ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3.1-405B base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
load_in_4bit: true load_in_4bit: true
@@ -10,10 +10,11 @@ datasets:
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b output_dir: ./outputs/out/qlora-llama3_1-405b
save_safetensors: true
adapter: qlora adapter: qlora
sequence_len: 1024 sequence_len: 2048
sample_packing: true sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
@@ -25,7 +26,7 @@ lora_target_linear: true
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 2
optimizer: adamw_torch optimizer: adamw_torch
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.00001 learning_rate: 0.00001

View File

@@ -1,9 +1,9 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.11.1 peft==0.11.1
transformers==4.43.3 transformers @ git+https://github.com/huggingface/transformers.git@026a173a64372e9602a16523b8fae9de4b0ff428
tokenizers==0.19.1 tokenizers==0.19.1
bitsandbytes==0.43.1 bitsandbytes==0.43.3
accelerate==0.32.0 accelerate==0.32.0
deepspeed==0.14.4 deepspeed==0.14.4
pydantic==2.6.3 pydantic==2.6.3

View File

@@ -40,7 +40,7 @@ from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_tokenizer from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -382,6 +382,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
prepare_optim_env(cfg) prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg) normalize_config(cfg)
normalize_cfg_datasets(cfg) normalize_cfg_datasets(cfg)

View File

@@ -1243,7 +1243,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.fsdp: if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config: if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config) training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
if self.cfg.adapter == "qlora": if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True training_arguments_kwargs["qlora"] = True

View File

@@ -235,6 +235,12 @@ class LoraConfig(BaseModel):
peft_use_rslora: Optional[bool] = None peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None peft_layer_replication: Optional[List[Tuple[int, int]]] = None
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
metadata={
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: Optional[bool] = None lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None gptq: Optional[bool] = None
bnb_config_kwargs: Optional[Dict[str, Any]] = None bnb_config_kwargs: Optional[Dict[str, Any]] = None
@@ -939,6 +945,8 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_eval_packing(cls, data): def check_eval_packing(cls, data):
# TODO also should check test_datasets and val_set_size as we can skip
# if there are no eval datasets/splits
if ( if (
data.get("sample_packing") data.get("sample_packing")
and data.get("eval_table_size") and data.get("eval_table_size")

View File

@@ -13,6 +13,7 @@ from fastcore.parallel import parallel
from torch import Tensor, nn from torch import Tensor, nn
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from transformers.quantizers import AutoHfQuantizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
@@ -173,6 +174,7 @@ def load_sharded_model_quant(
low_memory=True, low_memory=True,
verbose=False, verbose=False,
loading_workers=2, loading_workers=2,
quantization_config=None,
): ):
with init_empty_weights(): with init_empty_weights():
model = AutoModelForCausalLM.from_config( model = AutoModelForCausalLM.from_config(
@@ -186,15 +188,26 @@ def load_sharded_model_quant(
compute_dtype=compute_dtype, compute_dtype=compute_dtype,
quant_type="nf4", quant_type="nf4",
quant_storage=quant_storage, quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
) )
else: else:
# this is the more common case with HF transformers # this is the more common case with HF transformers
# TODO can we detect the model arch and dynamically set skip_modules
model.model = _replace_linear( model.model = _replace_linear(
model.model, model.model,
Linear4bit, Linear4bit,
compute_dtype=compute_dtype, compute_dtype=compute_dtype,
quant_type="nf4", quant_type="nf4",
quant_storage=quant_storage, quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
) )
model.is_loaded_in_4bit = True model.is_loaded_in_4bit = True
@@ -251,6 +264,11 @@ def load_sharded_model_quant(
quant_method=quant_method, quant_method=quant_method,
) )
# these attributes are needed to inform transformers/peft of the quantization
model.is_quantized = True
model.quantization_method = "bitsandbytes"
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose: if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds") print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading # cleanup any extra memory usage from parallel loading

View File

@@ -624,14 +624,21 @@ def load_model(
elif ( elif (
qlora_fsdp qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx" and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
): ):
quant_storage = cfg.torch_dtype quant_storage = cfg.torch_dtype
quantization_config = hasattr(
model_config, "quantization_config"
) and getattr(model_config, "quantization_config")
quantization_config = (
quantization_config or model_kwargs["quantization_config"]
)
model = load_sharded_model_quant( model = load_sharded_model_quant(
base_model, base_model,
model_config, model_config,
cfg, cfg,
quant_storage=quant_storage, quant_storage=quant_storage,
quantization_config=quantization_config,
) )
skip_move_to_device = True skip_move_to_device = True
elif ( elif (

View File

@@ -393,10 +393,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
def setup_deepspeed_env(cfg, stage=None): def setup_deepspeed_env(cfg, stage=None):
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if cfg.bf16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
if stage: if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage) os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3: if stage == 3:
@@ -444,6 +440,12 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
def prepare_opinionated_env(cfg):
if cfg.qlora_sharded_model_loading:
# model loading is forked after the tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]: if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)