qlora-fsdp ram efficient loading with hf trainer (#1791)
* fix 405b with lower cpu ram requirements * make sure to use double quant and only skip output embeddings * set model attributes * more fixes for sharded fsdp loading * update the base model in example to use pre-quantized nf4-bf16 weights * upstream fixes for qlora+fsdp
This commit is contained in:
@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
|
|||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
|
||||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
|
|||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
|
||||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: meta-llama/Meta-Llama-3.1-405B
|
base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
@@ -10,10 +10,11 @@ datasets:
|
|||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out/qlora-llama3_1-405b
|
output_dir: ./outputs/out/qlora-llama3_1-405b
|
||||||
|
save_safetensors: true
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
|
|
||||||
sequence_len: 1024
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
@@ -25,7 +26,7 @@ lora_target_linear: true
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 2
|
||||||
optimizer: adamw_torch
|
optimizer: adamw_torch
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.00001
|
learning_rate: 0.00001
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.11.1
|
peft==0.11.1
|
||||||
transformers==4.43.3
|
transformers @ git+https://github.com/huggingface/transformers.git@026a173a64372e9602a16523b8fae9de4b0ff428
|
||||||
tokenizers==0.19.1
|
tokenizers==0.19.1
|
||||||
bitsandbytes==0.43.1
|
bitsandbytes==0.43.3
|
||||||
accelerate==0.32.0
|
accelerate==0.32.0
|
||||||
deepspeed==0.14.4
|
deepspeed==0.14.4
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from axolotl.utils.distributed import is_main_process
|
|||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import prepare_optim_env
|
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -382,6 +382,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
prepare_optim_env(cfg)
|
prepare_optim_env(cfg)
|
||||||
|
|
||||||
|
prepare_opinionated_env(cfg)
|
||||||
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
normalize_cfg_datasets(cfg)
|
normalize_cfg_datasets(cfg)
|
||||||
|
|||||||
@@ -1243,7 +1243,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||||
if self.cfg.fsdp_config:
|
if self.cfg.fsdp_config:
|
||||||
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
|
training_arguments_kwargs["fsdp_config"] = {
|
||||||
|
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
|
||||||
|
}
|
||||||
|
|
||||||
if self.cfg.adapter == "qlora":
|
if self.cfg.adapter == "qlora":
|
||||||
training_arguments_kwargs["qlora"] = True
|
training_arguments_kwargs["qlora"] = True
|
||||||
|
|||||||
@@ -235,6 +235,12 @@ class LoraConfig(BaseModel):
|
|||||||
peft_use_rslora: Optional[bool] = None
|
peft_use_rslora: Optional[bool] = None
|
||||||
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
||||||
|
|
||||||
|
qlora_sharded_model_loading: Optional[bool] = Field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
|
||||||
|
},
|
||||||
|
)
|
||||||
lora_on_cpu: Optional[bool] = None
|
lora_on_cpu: Optional[bool] = None
|
||||||
gptq: Optional[bool] = None
|
gptq: Optional[bool] = None
|
||||||
bnb_config_kwargs: Optional[Dict[str, Any]] = None
|
bnb_config_kwargs: Optional[Dict[str, Any]] = None
|
||||||
@@ -939,6 +945,8 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_eval_packing(cls, data):
|
def check_eval_packing(cls, data):
|
||||||
|
# TODO also should check test_datasets and val_set_size as we can skip
|
||||||
|
# if there are no eval datasets/splits
|
||||||
if (
|
if (
|
||||||
data.get("sample_packing")
|
data.get("sample_packing")
|
||||||
and data.get("eval_table_size")
|
and data.get("eval_table_size")
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from fastcore.parallel import parallel
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
from transformers.quantizers import AutoHfQuantizer
|
||||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
|
||||||
|
|
||||||
|
|
||||||
@@ -173,6 +174,7 @@ def load_sharded_model_quant(
|
|||||||
low_memory=True,
|
low_memory=True,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
loading_workers=2,
|
loading_workers=2,
|
||||||
|
quantization_config=None,
|
||||||
):
|
):
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = AutoModelForCausalLM.from_config(
|
model = AutoModelForCausalLM.from_config(
|
||||||
@@ -186,15 +188,26 @@ def load_sharded_model_quant(
|
|||||||
compute_dtype=compute_dtype,
|
compute_dtype=compute_dtype,
|
||||||
quant_type="nf4",
|
quant_type="nf4",
|
||||||
quant_storage=quant_storage,
|
quant_storage=quant_storage,
|
||||||
|
compress_statistics=True, # bnb_4bit_use_double_quant
|
||||||
|
skip_modules=[
|
||||||
|
"lm_head",
|
||||||
|
"embed_out",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# this is the more common case with HF transformers
|
# this is the more common case with HF transformers
|
||||||
|
# TODO can we detect the model arch and dynamically set skip_modules
|
||||||
model.model = _replace_linear(
|
model.model = _replace_linear(
|
||||||
model.model,
|
model.model,
|
||||||
Linear4bit,
|
Linear4bit,
|
||||||
compute_dtype=compute_dtype,
|
compute_dtype=compute_dtype,
|
||||||
quant_type="nf4",
|
quant_type="nf4",
|
||||||
quant_storage=quant_storage,
|
quant_storage=quant_storage,
|
||||||
|
compress_statistics=True, # bnb_4bit_use_double_quant
|
||||||
|
skip_modules=[
|
||||||
|
"lm_head",
|
||||||
|
"embed_out",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
model.is_loaded_in_4bit = True
|
model.is_loaded_in_4bit = True
|
||||||
|
|
||||||
@@ -251,6 +264,11 @@ def load_sharded_model_quant(
|
|||||||
quant_method=quant_method,
|
quant_method=quant_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# these attributes are needed to inform transformers/peft of the quantization
|
||||||
|
model.is_quantized = True
|
||||||
|
model.quantization_method = "bitsandbytes"
|
||||||
|
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
|
||||||
|
|
||||||
if cfg.local_rank == 0 and verbose:
|
if cfg.local_rank == 0 and verbose:
|
||||||
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
|
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
|
||||||
# cleanup any extra memory usage from parallel loading
|
# cleanup any extra memory usage from parallel loading
|
||||||
|
|||||||
@@ -624,14 +624,21 @@ def load_model(
|
|||||||
elif (
|
elif (
|
||||||
qlora_fsdp
|
qlora_fsdp
|
||||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||||
and cfg.model_config_type == "dbrx"
|
and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
|
||||||
):
|
):
|
||||||
quant_storage = cfg.torch_dtype
|
quant_storage = cfg.torch_dtype
|
||||||
|
quantization_config = hasattr(
|
||||||
|
model_config, "quantization_config"
|
||||||
|
) and getattr(model_config, "quantization_config")
|
||||||
|
quantization_config = (
|
||||||
|
quantization_config or model_kwargs["quantization_config"]
|
||||||
|
)
|
||||||
model = load_sharded_model_quant(
|
model = load_sharded_model_quant(
|
||||||
base_model,
|
base_model,
|
||||||
model_config,
|
model_config,
|
||||||
cfg,
|
cfg,
|
||||||
quant_storage=quant_storage,
|
quant_storage=quant_storage,
|
||||||
|
quantization_config=quantization_config,
|
||||||
)
|
)
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
elif (
|
elif (
|
||||||
|
|||||||
@@ -393,10 +393,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
def setup_deepspeed_env(cfg, stage=None):
|
def setup_deepspeed_env(cfg, stage=None):
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
||||||
if cfg.bf16:
|
|
||||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
|
||||||
elif cfg.fp16:
|
|
||||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
|
|
||||||
if stage:
|
if stage:
|
||||||
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
||||||
if stage == 3:
|
if stage == 3:
|
||||||
@@ -444,6 +440,12 @@ def prepare_optim_env(cfg):
|
|||||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_opinionated_env(cfg):
|
||||||
|
if cfg.qlora_sharded_model_loading:
|
||||||
|
# model loading is forked after the tokenizer
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||||
|
|||||||
Reference in New Issue
Block a user