Compare commits

...

10 Commits

Author     SHA1        Message                                               Date
Wing Lian  87e8f13056  repalce linear layers for qlora as well as add peft   2023-11-01 22:31:02 -04:00
Wing Lian  026172eaa8  remove unused code, support adapter for tensor parallel  2023-11-01 20:31:51 -04:00
Wing Lian  b3689f73e3  chore: lint                                           2023-11-01 20:25:10 -04:00
Wing Lian  c4664ba8ee  tp fixes                                              2023-11-01 18:50:18 -04:00
Wing Lian  75e4fc2825  wip more tp fixes                                     2023-11-01 01:45:36 -04:00
Wing Lian  e13c2fd6b1  getting better                                        2023-10-31 22:23:40 -04:00
Wing Lian  8a21e14a21  load to cpu first                                     2023-10-31 22:23:15 -04:00
Wing Lian  9c52a83403  load model faster w low_cpu_mem_usage                 2023-10-31 22:23:15 -04:00
Wing Lian  fb8ee37ca6  wip tp                                                2023-10-31 22:23:14 -04:00
Wing Lian  65f3a4f703  tensor-parallel support                               2023-10-31 22:21:40 -04:00
5 changed files with 92 additions and 15 deletions

View File

@@ -31,3 +31,4 @@ scikit-learn==1.2.2
 pynvml
 art
 fschat==0.2.29
+tensor_parallel
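Note: the unpinned tensor_parallel requirement presumably refers to the tensor_parallel package on PyPI, which provides the tensor_parallel() and infer_sharded_device_map() helpers that the trainer changes below import as tp.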

View File

@@ -14,6 +14,7 @@ from functools import partial
 from pathlib import Path
 from typing import Optional, Union

+import tensor_parallel as tp
 import torch
 import transformers
 from datasets import Dataset
@@ -33,6 +34,7 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
+from axolotl.utils.distributed import is_distributed
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

 try:
@@ -102,6 +104,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     bench_source_max_len: int = field(
         default=2048, metadata={"help": "Maximum source sequence length for bench."}
     )
+    tensor_parallel: bool = field(
+        default=False, metadata={"help": "Use tensor parallelism to train"}
+    )


 class AxolotlTrainer(Trainer):
@@ -246,6 +251,14 @@ class AxolotlTrainer(Trainer):
         # return (loss, outputs) if return_outputs else loss
         return super().compute_loss(model, inputs, return_outputs=return_outputs)

+    def _wrap_model(self, model, training=True, dataloader=None):
+        if self.args.tensor_parallel:
+            model = tp.tensor_parallel(model, distributed=is_distributed())
+            model.hf_device_map = tp.infer_sharded_device_map(model)
+        else:
+            model = super()._wrap_model(model, training=training, dataloader=dataloader)
+        return model
+

 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
@@ -371,7 +384,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         return trainer_kwargs, trainer_cls

     def hook_post_create_trainer(self, trainer):
-        # TODO
+        if self.cfg.tensor_parallel:
+            trainer.model = trainer.accelerator.prepare_model(
+                trainer.model, device_placement=True
+            )
         return trainer

     def get_callbacks(self):
@@ -615,6 +631,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.micro_batch_size
         training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
+        training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True
+
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
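For orientation, a minimal sketch of what the new _wrap_model override delegates to the tensor_parallel library when args.tensor_parallel is set. The checkpoint name and single-node multi-GPU setup are assumptions for illustration, not part of this change:

import tensor_parallel as tp
import torch
from transformers import AutoModelForCausalLM

# Load on CPU first (cf. the "load to cpu first" commit), then shard.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # hypothetical checkpoint for illustration
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Shard the weights across all visible GPUs; the PR passes
# distributed=is_distributed() here when running under torch.distributed.
model = tp.tensor_parallel(model)
model.hf_device_map = tp.infer_sharded_device_map(model)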

View File

@@ -1,10 +1,13 @@
 """Benchmarking and measurement utilities"""
 import functools
+import logging

 import pynvml
 import torch
 from pynvml.nvml import NVMLError

+LOG = logging.getLogger("axolotl.utils.bench")
+

 def check_cuda_device(default_value):
     """
@@ -62,7 +65,14 @@ def gpu_memory_usage_smi(device=0):


 def log_gpu_memory_usage(log, msg, device):
-    usage, cache, misc = gpu_memory_usage_all(device)
+    if not torch.cuda.is_available():
+        return (0, 0, 0)
+
+    try:
+        usage, cache, misc = gpu_memory_usage_all(device)
+    except ValueError as exc:
+        LOG.exception(exc)
+        return (0, 0, 0)

     extras = []
     if cache > 0:
         extras.append(f"+{cache:.03f}GB cache")

View File

@@ -369,6 +369,10 @@ def validate_config(cfg):
             "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
         )

+    if cfg.tensor_parallel and cfg.gradient_checkpointing:
+        raise ValueError(
+            "TensorParallelPreTrainedModel does not support gradient checkpointing"
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
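A quick illustration of the new guard, assuming the usual axolotl import paths; the config stub is minimal and hypothetical:

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "tensor_parallel": True,
        "gradient_checkpointing": True,  # incompatible with TensorParallelPreTrainedModel
        # ... other required fields omitted
    }
)
validate_config(cfg)  # raises ValueError per the check added above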

View File

@@ -7,6 +7,7 @@ from typing import Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
+import transformers.utils.bitsandbytes
 from optimum.bettertransformer import BetterTransformer
 from peft import PeftConfig, prepare_model_for_kbit_training
 from peft.tuners.lora import QuantLinear
@@ -221,7 +222,7 @@ def load_model(
             load_in_4bit=True,
             llm_int8_threshold=6.0,
             llm_int8_has_fp16_weight=False,
-            bnb_4bit_compute_dtype=cfg.torch_dtype,
+            bnb_4bit_compute_dtype=torch.float16,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
@@ -235,7 +236,12 @@ def load_model(
         model_kwargs["use_flash_attention_2"] = True

     try:
-        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
+        if (
+            cfg.is_llama_derived_model
+            and not cfg.trust_remote_code
+            and not cfg.gptq
+            and not cfg.tensor_parallel
+        ):
             from transformers import LlamaForCausalLM

             config_kwargs = {}
@@ -301,7 +307,7 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
             )
-        elif model_type and not cfg.trust_remote_code:
+        elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel:
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
@@ -316,6 +322,17 @@ def load_model(
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
+        elif cfg.tensor_parallel:
+            model_kwargs.pop("device_map")
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                low_cpu_mem_usage=True,
+                offload_state_dict=True,
+                trust_remote_code=cfg.trust_remote_code or False,
+                **model_kwargs,
+            )
         else:
             config = AutoConfig.from_pretrained(
                 base_model,
@@ -366,6 +383,7 @@ def load_model(
         **model_kwargs,
     )

+    try:
         embeddings_len = (
             math.ceil(len(tokenizer) / 32) * 32
             if cfg.resize_token_embeddings_to_32x
@@ -375,6 +393,8 @@ def load_model(
             model.resize_token_embeddings(embeddings_len)
         else:
             model.tie_weights()
+    except NotImplementedError:
+        LOG.warning("`resize_token_embeddings` not implemented on model")

     if (
         hasattr(model.config, "max_position_embeddings")
@@ -477,7 +497,12 @@ def load_adapter(model, cfg, adapter, inference=False):
     if adapter is None:
         return model, None
     if hasattr(model, "enable_input_require_grads"):
-        model.enable_input_require_grads()
+        try:
+            model.enable_input_require_grads()
+        except NotImplementedError:
+            LOG.warning("enable_input_require_grads not implemented on model")
+    if adapter == "qlora" and cfg.tensor_parallel:
+        model, _ = load_tp_qlora(model)
     if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg, inference=inference)
     if adapter == "llama-adapter":
@@ -529,6 +554,25 @@ def find_all_linear_names(model):
return list(lora_module_names) return list(lora_module_names)
def load_tp_qlora(model):
from transformers.utils.bitsandbytes import replace_with_bnb_linear
model = replace_with_bnb_linear(
model,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
),
)
model.is_loaded_in_4bit = True
return model, None
def load_lora(model, cfg, inference=False): def load_lora(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
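Taken together, the tensor-parallel QLoRA path sketched by these hunks is roughly: load_model() loads the base model on CPU without a device_map (low_cpu_mem_usage=True, offload_state_dict=True); load_adapter() calls load_tp_qlora() to swap eligible Linear layers for bitsandbytes 4-bit layers and mark the model as loaded in 4-bit before falling through to load_lora(); and AxolotlTrainer._wrap_model() then shards the result across GPUs with tp.tensor_parallel(). The bnb_4bit_compute_dtype change to torch.float16 and the NotImplementedError guards appear to accommodate the TensorParallelPreTrainedModel wrapper, which, per the new validation error, also does not support gradient checkpointing.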