Compare commits

10 Commits

mistral-su... ... tensor-par...

| Author | SHA1 | Date |
|---|---|---|
|  | 87e8f13056 |  |
|  | 026172eaa8 |  |
|  | b3689f73e3 |  |
|  | c4664ba8ee |  |
|  | 75e4fc2825 |  |
|  | e13c2fd6b1 |  |
|  | 8a21e14a21 |  |
|  | 9c52a83403 |  |
|  | fb8ee37ca6 |  |
|  | 65f3a4f703 |  |
@@ -31,3 +31,4 @@ scikit-learn==1.2.2
 pynvml
 art
 fschat==0.2.29
+tensor_parallel

@@ -14,6 +14,7 @@ from functools import partial
 from pathlib import Path
 from typing import Optional, Union

+import tensor_parallel as tp
 import torch
 import transformers
 from datasets import Dataset

@@ -33,6 +34,7 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
+from axolotl.utils.distributed import is_distributed
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

 try:

@@ -102,6 +104,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     bench_source_max_len: int = field(
         default=2048, metadata={"help": "Maximum source sequence length for bench."}
     )
+    tensor_parallel: bool = field(
+        default=False, metadata={"help": "Use tensor parallelism to train"}
+    )


 class AxolotlTrainer(Trainer):

@@ -246,6 +251,14 @@ class AxolotlTrainer(Trainer):
         # return (loss, outputs) if return_outputs else loss
         return super().compute_loss(model, inputs, return_outputs=return_outputs)

+    def _wrap_model(self, model, training=True, dataloader=None):
+        if self.args.tensor_parallel:
+            model = tp.tensor_parallel(model, distributed=is_distributed())
+            model.hf_device_map = tp.infer_sharded_device_map(model)
+        else:
+            model = super()._wrap_model(model, training=training, dataloader=dataloader)
+        return model
+

 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """

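The `_wrap_model` override above hands model wrapping to the `tensor_parallel` package instead of the usual Accelerate/DDP path whenever tensor parallelism is enabled. As a rough standalone sketch of what that wrapping does outside the trainer (the model id and device list below are illustrative, not taken from this diff):

```python
# Minimal sketch, assuming the `tensor_parallel` PyPI package added to
# requirements.txt above; the model id and device list are illustrative only.
import tensor_parallel as tp
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
# Split the module's weights across the listed CUDA devices.
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])
# Infer a device map for the sharded module (assigned to hf_device_map above).
device_map = tp.infer_sharded_device_map(model)
```
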
@@ -371,7 +384,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         return trainer_kwargs, trainer_cls

     def hook_post_create_trainer(self, trainer):
-        # TODO
+        if self.cfg.tensor_parallel:
+            trainer.model = trainer.accelerator.prepare_model(
+                trainer.model, device_placement=True
+            )
         return trainer

     def get_callbacks(self):

@@ -615,6 +631,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.micro_batch_size
         training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
+        training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True
+
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )

@@ -1,10 +1,13 @@
 """Benchmarking and measurement utilities"""
 import functools
+import logging

 import pynvml
 import torch
 from pynvml.nvml import NVMLError

+LOG = logging.getLogger("axolotl.utils.bench")
+

 def check_cuda_device(default_value):
     """

@@ -62,7 +65,14 @@ def gpu_memory_usage_smi(device=0):


 def log_gpu_memory_usage(log, msg, device):
-    usage, cache, misc = gpu_memory_usage_all(device)
+    if not torch.cuda.is_available():
+        return (0, 0, 0)
+
+    try:
+        usage, cache, misc = gpu_memory_usage_all(device)
+    except ValueError as exc:
+        LOG.exception(exc)
+        return (0, 0, 0)
     extras = []
     if cache > 0:
         extras.append(f"+{cache:.03f}GB cache")

@@ -369,6 +369,10 @@ def validate_config(cfg):
             "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
         )

+    if cfg.tensor_parallel and cfg.gradient_checkpointing:
+        raise ValueError(
+            "TensorParallelPreTrainedModel does not support gradient checkpointing"
+        )
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25

@@ -7,6 +7,7 @@ from typing import Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
+import transformers.utils.bitsandbytes
 from optimum.bettertransformer import BetterTransformer
 from peft import PeftConfig, prepare_model_for_kbit_training
 from peft.tuners.lora import QuantLinear

@@ -221,7 +222,7 @@ def load_model(
             load_in_4bit=True,
             llm_int8_threshold=6.0,
             llm_int8_has_fp16_weight=False,
-            bnb_4bit_compute_dtype=cfg.torch_dtype,
+            bnb_4bit_compute_dtype=torch.float16,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )

@@ -235,7 +236,12 @@ def load_model(
         model_kwargs["use_flash_attention_2"] = True

     try:
-        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
+        if (
+            cfg.is_llama_derived_model
+            and not cfg.trust_remote_code
+            and not cfg.gptq
+            and not cfg.tensor_parallel
+        ):
             from transformers import LlamaForCausalLM

             config_kwargs = {}

@@ -301,7 +307,7 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
             )
-        elif model_type and not cfg.trust_remote_code:
+        elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel:
            if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,

@@ -316,6 +322,17 @@ def load_model(
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
+        elif cfg.tensor_parallel:
+            model_kwargs.pop("device_map")
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                low_cpu_mem_usage=True,
+                offload_state_dict=True,
+                trust_remote_code=cfg.trust_remote_code or False,
+                **model_kwargs,
+            )
         else:
             config = AutoConfig.from_pretrained(
                 base_model,

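In the new `elif cfg.tensor_parallel:` branch the `device_map` entry is popped from `model_kwargs` before calling `from_pretrained`, presumably so Accelerate does not place the weights itself: the model is materialized once with `low_cpu_mem_usage=True` and state-dict offload, and is only sharded across GPUs later, when the trainer's `_wrap_model` calls `tp.tensor_parallel`.
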
@@ -366,15 +383,18 @@ def load_model(
             **model_kwargs,
         )

-    embeddings_len = (
-        math.ceil(len(tokenizer) / 32) * 32
-        if cfg.resize_token_embeddings_to_32x
-        else len(tokenizer)
-    )
-    if model.get_input_embeddings().num_embeddings < embeddings_len:
-        model.resize_token_embeddings(embeddings_len)
-    else:
-        model.tie_weights()
+    try:
+        embeddings_len = (
+            math.ceil(len(tokenizer) / 32) * 32
+            if cfg.resize_token_embeddings_to_32x
+            else len(tokenizer)
+        )
+        if model.get_input_embeddings().num_embeddings < embeddings_len:
+            model.resize_token_embeddings(embeddings_len)
+        else:
+            model.tie_weights()
+    except NotImplementedError:
+        LOG.warning("`resize_token_embeddings` not implemented on model")

     if (
         hasattr(model.config, "max_position_embeddings")

@@ -477,7 +497,12 @@ def load_adapter(model, cfg, adapter, inference=False):
     if adapter is None:
         return model, None
     if hasattr(model, "enable_input_require_grads"):
-        model.enable_input_require_grads()
+        try:
+            model.enable_input_require_grads()
+        except NotImplementedError:
+            LOG.warning("enable_input_require_grads not implemented on model")
+    if adapter == "qlora" and cfg.tensor_parallel:
+        model, _ = load_tp_qlora(model)
     if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg, inference=inference)
     if adapter == "llama-adapter":

@@ -529,6 +554,25 @@ def find_all_linear_names(model):
     return list(lora_module_names)


+def load_tp_qlora(model):
+    from transformers.utils.bitsandbytes import replace_with_bnb_linear
+
+    model = replace_with_bnb_linear(
+        model,
+        quantization_config=BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        ),
+    )
+    model.is_loaded_in_4bit = True
+
+    return model, None
+
+
 def load_lora(model, cfg, inference=False):
     # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]

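`load_tp_qlora` quantizes an already-loaded model in place: `replace_with_bnb_linear` swaps its `nn.Linear` modules for bitsandbytes 4-bit layers (the weights themselves are quantized once the module is moved to a CUDA device), and `is_loaded_in_4bit` is set manually so downstream code treats the result like a normally loaded 4-bit checkpoint. A hedged sketch of the intended order of operations, quantize first and shard afterwards (the model id is illustrative; `load_tp_qlora` is the helper added in the hunk above):

```python
# Sketch only; assumes the imports used in this diff (tensor_parallel as tp,
# AutoModelForCausalLM) and the load_tp_qlora helper defined above.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # illustrative
model, _ = load_tp_qlora(model)    # nn.Linear -> bnb 4-bit layers, flag set
model = tp.tensor_parallel(model)  # shard across visible GPUs, as in _wrap_model
```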