qlora-fsdp ram efficient loading with hf trainer (#1791)
* fix 405b with lower cpu ram requirements * make sure to use double quant and only skip output embeddings * set model attributes * more fixes for sharded fsdp loading * update the base model in example to use pre-quantized nf4-bf16 weights * upstream fixes for qlora+fsdp
This commit is contained in:
@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: meta-llama/Meta-Llama-3.1-405B
|
||||
base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_4bit: true
|
||||
@@ -10,10 +10,11 @@ datasets:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out/qlora-llama3_1-405b
|
||||
save_safetensors: true
|
||||
|
||||
adapter: qlora
|
||||
|
||||
sequence_len: 1024
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
@@ -25,7 +26,7 @@ lora_target_linear: true
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 2
|
||||
optimizer: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.00001
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft==0.11.1
|
||||
transformers==4.43.3
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@026a173a64372e9602a16523b8fae9de4b0ff428
|
||||
tokenizers==0.19.1
|
||||
bitsandbytes==0.43.1
|
||||
bitsandbytes==0.43.3
|
||||
accelerate==0.32.0
|
||||
deepspeed==0.14.4
|
||||
pydantic==2.6.3
|
||||
|
||||
@@ -40,7 +40,7 @@ from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
@@ -382,6 +382,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
prepare_opinionated_env(cfg)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
normalize_cfg_datasets(cfg)
|
||||
|
||||
@@ -1243,7 +1243,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
|
||||
training_arguments_kwargs["fsdp_config"] = {
|
||||
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
|
||||
}
|
||||
|
||||
if self.cfg.adapter == "qlora":
|
||||
training_arguments_kwargs["qlora"] = True
|
||||
|
||||
@@ -235,6 +235,12 @@ class LoraConfig(BaseModel):
|
||||
peft_use_rslora: Optional[bool] = None
|
||||
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
||||
|
||||
qlora_sharded_model_loading: Optional[bool] = Field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
|
||||
},
|
||||
)
|
||||
lora_on_cpu: Optional[bool] = None
|
||||
gptq: Optional[bool] = None
|
||||
bnb_config_kwargs: Optional[Dict[str, Any]] = None
|
||||
@@ -939,6 +945,8 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_eval_packing(cls, data):
|
||||
# TODO also should check test_datasets and val_set_size as we can skip
|
||||
# if there are no eval datasets/splits
|
||||
if (
|
||||
data.get("sample_packing")
|
||||
and data.get("eval_table_size")
|
||||
|
||||
@@ -13,6 +13,7 @@ from fastcore.parallel import parallel
|
||||
from torch import Tensor, nn
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.quantizers import AutoHfQuantizer
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
|
||||
|
||||
|
||||
@@ -173,6 +174,7 @@ def load_sharded_model_quant(
|
||||
low_memory=True,
|
||||
verbose=False,
|
||||
loading_workers=2,
|
||||
quantization_config=None,
|
||||
):
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
@@ -186,15 +188,26 @@ def load_sharded_model_quant(
|
||||
compute_dtype=compute_dtype,
|
||||
quant_type="nf4",
|
||||
quant_storage=quant_storage,
|
||||
compress_statistics=True, # bnb_4bit_use_double_quant
|
||||
skip_modules=[
|
||||
"lm_head",
|
||||
"embed_out",
|
||||
],
|
||||
)
|
||||
else:
|
||||
# this is the more common case with HF transformers
|
||||
# TODO can we detect the model arch and dynamically set skip_modules
|
||||
model.model = _replace_linear(
|
||||
model.model,
|
||||
Linear4bit,
|
||||
compute_dtype=compute_dtype,
|
||||
quant_type="nf4",
|
||||
quant_storage=quant_storage,
|
||||
compress_statistics=True, # bnb_4bit_use_double_quant
|
||||
skip_modules=[
|
||||
"lm_head",
|
||||
"embed_out",
|
||||
],
|
||||
)
|
||||
model.is_loaded_in_4bit = True
|
||||
|
||||
@@ -251,6 +264,11 @@ def load_sharded_model_quant(
|
||||
quant_method=quant_method,
|
||||
)
|
||||
|
||||
# these attributes are needed to inform transformers/peft of the quantization
|
||||
model.is_quantized = True
|
||||
model.quantization_method = "bitsandbytes"
|
||||
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
|
||||
|
||||
if cfg.local_rank == 0 and verbose:
|
||||
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
|
||||
# cleanup any extra memory usage from parallel loading
|
||||
|
||||
@@ -624,14 +624,21 @@ def load_model(
|
||||
elif (
|
||||
qlora_fsdp
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and cfg.model_config_type == "dbrx"
|
||||
and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
|
||||
):
|
||||
quant_storage = cfg.torch_dtype
|
||||
quantization_config = hasattr(
|
||||
model_config, "quantization_config"
|
||||
) and getattr(model_config, "quantization_config")
|
||||
quantization_config = (
|
||||
quantization_config or model_kwargs["quantization_config"]
|
||||
)
|
||||
model = load_sharded_model_quant(
|
||||
base_model,
|
||||
model_config,
|
||||
cfg,
|
||||
quant_storage=quant_storage,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
skip_move_to_device = True
|
||||
elif (
|
||||
|
||||
@@ -393,10 +393,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
def setup_deepspeed_env(cfg, stage=None):
|
||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
||||
if cfg.bf16:
|
||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
||||
elif cfg.fp16:
|
||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
|
||||
if stage:
|
||||
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
||||
if stage == 3:
|
||||
@@ -444,6 +440,12 @@ def prepare_optim_env(cfg):
|
||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
|
||||
|
||||
|
||||
def prepare_opinionated_env(cfg):
|
||||
if cfg.qlora_sharded_model_loading:
|
||||
# model loading is forked after the tokenizer
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user