various batch of fixes (#1785)
* various batch of fixes * more tweaks * fix autoawq requirement for torch flexibility * simplify conditionals * multi-node fixes wip * bump transformers and include 405b qlora+fsdp yaml
This commit is contained in:
62
examples/llama-3/qlora-fsdp-405b.yaml
Normal file
62
examples/llama-3/qlora-fsdp-405b.yaml
Normal file
@@ -0,0 +1,62 @@
|
||||
base_model: meta-llama/Meta-Llama-3.1-405B
tokenizer_type: AutoTokenizer

load_in_4bit: true
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b

adapter: qlora

sequence_len: 1024
sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

train_on_inputs: false
group_by_length: false
bf16: true
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
logging_steps: 1
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
special_tokens:
  pad_token: <|finetune_right_pad_id|>
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft==0.11.1
|
||||
transformers==4.43.1
|
||||
transformers==4.43.3
|
||||
tokenizers==0.19.1
|
||||
bitsandbytes==0.43.1
|
||||
accelerate==0.32.0
|
||||
@@ -32,6 +32,7 @@ fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e59
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
python-dotenv==1.0.1
|
||||
autoawq>=0.2.5
|
||||
|
||||
mamba-ssm==1.2.0.post1
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -76,8 +77,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
|
||||
if parsed_cli_args.download:
|
||||
model_name = parsed_cfg.base_model
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
with warnings.catch_warnings():
|
||||
# there are a bunch of useless UserWarnings about
|
||||
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
|
||||
warnings.simplefilter("ignore")
|
||||
with init_empty_weights(include_buffers=True):
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
LOG.info(
|
||||
Fore.GREEN
|
||||
|
||||
14
src/axolotl/common/architectures.py
Normal file
14
src/axolotl/common/architectures.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Common architecture specific constants
|
||||
"""
|
||||
|
||||
MOE_ARCH_BLOCK = {
|
||||
"dbrx": "DbrxFFN",
|
||||
"jamba": "JambaSparseMoeBlock",
|
||||
"jetmoe": [
|
||||
"JetMoeMoA",
|
||||
"JetMoeMoE",
|
||||
],
|
||||
"mixtral": "MixtralSparseMoeBlock",
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import importlib
|
||||
import importlib.util
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
@@ -28,7 +29,7 @@ from transformers import (
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import (
|
||||
CPOConfig,
|
||||
@@ -286,7 +287,77 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
class SchedulerMixin(Trainer):
    """
    Mixin class for scheduler setup in CausalTrainer.

    Routes ``create_scheduler`` to axolotl's custom cosine variants
    (quadratic warmup, min-lr, warmup + decay-to-constant) when the
    corresponding training args request them, otherwise defers to the
    stock ``Trainer`` implementation.
    """

    args = None  # type: AxolotlTrainingArguments

    def create_scheduler(
        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
    ):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (torch.optim.Optimizer): The training optimizer
        """
        is_cosine = self.args.lr_scheduler_type == "cosine"
        use_cosine_quadratic = is_cosine and self.args.lr_quadratic_warmup is True
        use_cosine_min_lr = is_cosine and self.args.cosine_min_lr_ratio is not None

        # fmt: off
        if self.lr_scheduler is not None:  # type: ignore # pylint: disable=access-member-before-definition
            # fmt: on
            # a scheduler already exists (e.g. created by deepspeed); warn that
            # the axolotl-specific variants will not take effect
            if use_cosine_quadratic:
                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
            if use_cosine_min_lr:
                LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
            return self.lr_scheduler

        if use_cosine_quadratic:
            # quadratic warmup wins over min-lr when both are configured
            if use_cosine_min_lr:
                LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

            self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
            # cosine decay to a constant floor after warmup
            assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
            assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
            self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(  # pylint: disable=attribute-defined-outside-init
                optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
                min_lr_ratio=self.args.cosine_min_lr_ratio,
                constant_lr_ratio=self.args.cosine_constant_lr_ratio,
            )
        elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
            # plain cosine decay with a minimum-lr floor
            assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
            self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
                optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
                min_lr_ratio=self.args.cosine_min_lr_ratio,
            )
        else:
            # no axolotl-specific scheduler requested: use the stock Trainer logic
            return super().create_scheduler(num_training_steps, optimizer)

        return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
@@ -404,68 +475,6 @@ class AxolotlTrainer(Trainer):
|
||||
|
||||
return self.optimizer
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
use_cosine_quadratic = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
)
|
||||
|
||||
use_cosine_min_lr = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.cosine_min_lr_ratio is not None
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if use_cosine_quadratic:
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
if self.args.multipack_real_batches:
|
||||
@@ -830,6 +839,14 @@ class AxolotlTrainer(Trainer):
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
def _save_checkpoint(self, model, trial, metrics=None):
    """Ensure the checkpoint directory exists, then defer to the parent.

    The upstream Trainer can be flaky about the checkpoint dir being
    present when it writes, so create it up front.
    """
    folder_name = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
    run_dir = self._get_output_dir(trial=trial)
    os.makedirs(os.path.join(run_dir, folder_name), exist_ok=True)
    return super()._save_checkpoint(model, trial, metrics=metrics)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
@@ -929,7 +946,7 @@ class ReLoRATrainer(AxolotlTrainer):
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(DPOTrainer):
|
||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
@@ -990,7 +1007,7 @@ class AxolotlDPOTrainer(DPOTrainer):
|
||||
return res
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(ORPOTrainer):
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
@@ -998,7 +1015,7 @@ class AxolotlORPOTrainer(ORPOTrainer):
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(KTOTrainer):
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
@@ -1006,7 +1023,7 @@ class AxolotlKTOTrainer(KTOTrainer):
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(CPOTrainer):
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
@@ -1750,6 +1767,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rl == "simpo":
|
||||
training_args_cls = AxolotlCPOConfig
|
||||
training_args_kwargs["loss_type"] = "simpo"
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
@@ -62,7 +62,7 @@ def default(
|
||||
tokenize=False,
|
||||
)
|
||||
chosen_strip_index = result["chosen"].find(chosen["content"])
|
||||
result["chosen"] = result["chosen"][chosen_strip_index:]
|
||||
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
|
||||
|
||||
result["rejected"] = tokenizer.apply_chat_template(
|
||||
[rejected],
|
||||
@@ -71,7 +71,7 @@ def default(
|
||||
tokenize=False,
|
||||
)
|
||||
rejected_strip_index = result["rejected"].find(rejected["content"])
|
||||
result["rejected"] = result["rejected"][rejected_strip_index:]
|
||||
result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -212,26 +212,23 @@ def train(
|
||||
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
||||
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
|
||||
trainer.save_model(cfg.output_dir)
|
||||
|
||||
# the trainer saved a model.safetensors file in the output directory,
|
||||
# but it is a proxy model and should be deleted
|
||||
if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
|
||||
# but it is most likely a proxy model and if so, should be deleted
|
||||
maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||
maybe_sharded = os.path.exists(
|
||||
os.path.join(cfg.output_dir, "model.safetensors.index.json")
|
||||
)
|
||||
|
||||
if maybe_proxy and maybe_sharded:
|
||||
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
|
||||
LOG.info("This is a proxy model and should be deleted")
|
||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||
try:
|
||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
|
||||
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
|
||||
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|
||||
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
|
||||
# The model name saved is `pytorch_model.bin`
|
||||
unwrapped_model.save_pretrained(
|
||||
cfg.output_dir,
|
||||
is_main_process=trainer.accelerator.is_main_process,
|
||||
save_function=trainer.accelerator.save,
|
||||
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
||||
)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
@@ -42,7 +42,7 @@ from axolotl.prompters import (
|
||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||
from axolotl.utils.data.utils import md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
from axolotl.utils.distributed import is_local_main_process, zero_first
|
||||
from axolotl.utils.trainer import (
|
||||
calculate_total_num_steps,
|
||||
process_datasets_for_packing,
|
||||
@@ -54,7 +54,7 @@ LOG = logging.getLogger("axolotl")
|
||||
def prepare_dataset(cfg, tokenizer):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
with zero_first(is_main_process()):
|
||||
with zero_first(is_local_main_process()):
|
||||
if cfg.test_datasets:
|
||||
train_dataset, _, prompters = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
|
||||
@@ -170,6 +170,7 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
if dataset:
|
||||
# This is for the case where we already loaded a pretokenized dataset from the hub
|
||||
...
|
||||
elif (
|
||||
cfg.dataset_prepared_path
|
||||
@@ -198,6 +199,8 @@ def load_tokenized_prepared_datasets(
|
||||
def for_d_in_datasets(dataset_configs):
|
||||
for dataset in dataset_configs:
|
||||
if dataset.name and isinstance(dataset.name, list):
|
||||
# load_dataset doesn't properly handle multiple named configurations
|
||||
# at the same time for a given dataset
|
||||
for name in dataset.name:
|
||||
yield DictDefault({**dataset, "name": name})
|
||||
else:
|
||||
@@ -208,6 +211,8 @@ def load_tokenized_prepared_datasets(
|
||||
ds: Optional[Union[Dataset, DatasetDict]] = None
|
||||
ds_from_hub = False
|
||||
try:
|
||||
# this is just a basic check to see if the path is a
|
||||
# valid HF dataset that's loadable
|
||||
load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
|
||||
@@ -44,6 +44,10 @@ def is_main_process():
|
||||
return dist.get_rank() == 0
|
||||
|
||||
|
||||
def is_local_main_process():
    """Whether this process should act as the main process (see note)."""
    # NOTE(review): PartialState().is_main_process reflects the *global* main
    # process, but this function's name suggests the per-node (local) rank 0.
    # Confirm whether PartialState().is_local_main_process was intended.
    state = PartialState()
    return state.is_main_process
|
||||
|
||||
|
||||
def get_world_size():
    """Total number of distributed workers, from the ``WORLD_SIZE`` env var (default 1)."""
    world_size = os.getenv("WORLD_SIZE", "1")
    return int(world_size)
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from transformers import ( # noqa: F401
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
AwqConfig,
|
||||
BitsAndBytesConfig,
|
||||
GPTQConfig,
|
||||
PreTrainedModel,
|
||||
@@ -36,6 +37,7 @@ from transformers import ( # noqa: F401
|
||||
)
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
@@ -510,7 +512,25 @@ def load_model(
|
||||
model_kwargs["quantization_config"] = GPTQConfig(
|
||||
**model_config.quantization_config
|
||||
)
|
||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||
if (
|
||||
cfg.adapter in ["qlora", "lora"]
|
||||
and hasattr(model_config, "quantization_config")
|
||||
and model_config.quantization_config["quant_method"]
|
||||
in ["gptq", "awq", "bitsandbytes"]
|
||||
):
|
||||
if model_config.quantization_config["quant_method"] == "gptq":
|
||||
model_kwargs["quantization_config"] = GPTQConfig(
|
||||
**model_config.quantization_config
|
||||
)
|
||||
elif model_config.quantization_config["quant_method"] == "awq":
|
||||
model_kwargs["quantization_config"] = AwqConfig(
|
||||
**model_config.quantization_config
|
||||
)
|
||||
elif model_config.quantization_config["quant_method"] == "bitsandbytes":
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**model_config.quantization_config
|
||||
)
|
||||
elif cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||
bnb_config = {
|
||||
"load_in_4bit": True,
|
||||
"llm_int8_threshold": 6.0,
|
||||
@@ -785,12 +805,14 @@ def load_model(
|
||||
set_z3_leaf_modules,
|
||||
)
|
||||
|
||||
if cfg.model_config_type == "mixtral":
|
||||
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
|
||||
set_z3_leaf_modules(model, [moe_block])
|
||||
elif cfg.model_config_type == "dbrx":
|
||||
moe_block = get_module_class_from_name(model, "DbrxFFN")
|
||||
set_z3_leaf_modules(model, [moe_block])
|
||||
if cfg.model_config_type in MOE_ARCH_BLOCK:
|
||||
set_z3_leaf_modules(
|
||||
model,
|
||||
[
|
||||
get_module_class_from_name(model, module_name)
|
||||
for module_name in MOE_ARCH_BLOCK[cfg.model_config_type]
|
||||
],
|
||||
)
|
||||
|
||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||
@@ -804,6 +826,9 @@ def load_model(
|
||||
# make sure everything is in the same dtype
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if cfg.adapter in ["lora", "qlora"]:
|
||||
if cfg.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable(
|
||||
@@ -838,6 +863,9 @@ def load_model(
|
||||
else:
|
||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
skip_move_to_device = True
|
||||
|
||||
if (
|
||||
cfg.ddp
|
||||
and not load_in_8bit
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
@@ -389,6 +390,19 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
return total_num_steps
|
||||
|
||||
|
||||
def setup_deepspeed_env(cfg, stage=None):
    """Export the accelerate env vars that enable deepspeed for this run.

    Args:
        cfg: axolotl config; reads ``deepspeed`` (config path) and the
            ``bf16`` / ``fp16`` precision flags.
        stage: optional ZeRO stage parsed from the deepspeed config.
    """
    env = os.environ
    env["ACCELERATE_USE_DEEPSPEED"] = "true"
    env["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed

    # mixed precision follows the trainer dtype flags, bf16 taking priority
    if cfg.bf16:
        env["ACCELERATE_MIXED_PRECISION"] = "bf16"
    elif cfg.fp16:
        env["ACCELERATE_MIXED_PRECISION"] = "fp16"

    if stage:
        env["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
        if stage == 3:
            # ZeRO stage 3 additionally needs accelerate's zero3 init path
            env["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
||||
|
||||
|
||||
def setup_fsdp_envs(cfg):
|
||||
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||
if cfg.fsdp_config.fsdp_activation_checkpointing:
|
||||
@@ -415,8 +429,14 @@ def prepare_optim_env(cfg):
|
||||
if cfg.fsdp:
|
||||
setup_fsdp_envs(cfg)
|
||||
elif cfg.deepspeed:
|
||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
||||
stage = None
|
||||
# check if the cfg.deepspeed is a file
|
||||
if os.path.isfile(cfg.deepspeed):
|
||||
# parse with json
|
||||
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
|
||||
deepspeed_config = json.load(fin)
|
||||
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
|
||||
setup_deepspeed_env(cfg, stage=stage)
|
||||
|
||||
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
|
||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
||||
|
||||
Reference in New Issue
Block a user