various batch of fixes (#1785)

* various batch of fixes

* more tweaks

* fix autoawq requirement for torch flexibility

* simplify conditionals

* multi-node fixes wip

* bump transformers and include 405b qlora+fsdp yaml
This commit is contained in:
Wing Lian
2024-07-28 07:25:54 -04:00
committed by GitHub
parent 22680913f3
commit 94ba93259f
11 changed files with 253 additions and 99 deletions

View File

@@ -0,0 +1,62 @@
base_model: meta-llama/Meta-Llama-3.1-405B
tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b
adapter: qlora
sequence_len: 1024
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|finetune_right_pad_id|>

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.11.1 peft==0.11.1
transformers==4.43.1 transformers==4.43.3
tokenizers==0.19.1 tokenizers==0.19.1
bitsandbytes==0.43.1 bitsandbytes==0.43.1
accelerate==0.32.0 accelerate==0.32.0
@@ -32,6 +32,7 @@ fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e59
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
python-dotenv==1.0.1 python-dotenv==1.0.1
autoawq>=0.2.5
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1

View File

@@ -2,6 +2,7 @@
CLI to run training on a model CLI to run training on a model
""" """
import logging import logging
import warnings
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -76,8 +77,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
if parsed_cli_args.download: if parsed_cli_args.download:
model_name = parsed_cfg.base_model model_name = parsed_cfg.base_model
with init_empty_weights(): with warnings.catch_warnings():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) # there are a bunch of useless UserWarnings about
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
warnings.simplefilter("ignore")
with init_empty_weights(include_buffers=True):
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
LOG.info( LOG.info(
Fore.GREEN Fore.GREEN

View File

@@ -0,0 +1,14 @@
"""
Common architecture specific constants
"""
# Maps a model config type to the class names of its mixture-of-experts blocks.
#
# Every value is a list, even when an architecture has a single block class, so
# that consumers can iterate over an entry uniformly. Iterating over a bare
# string would yield its individual characters rather than the class name, and
# the per-name class lookup would then silently fail to find anything.
MOE_ARCH_BLOCK = {
    "dbrx": ["DbrxFFN"],
    "jamba": ["JambaSparseMoeBlock"],
    "jetmoe": [
        "JetMoeMoA",
        "JetMoeMoE",
    ],
    "mixtral": ["MixtralSparseMoeBlock"],
    "qwen2_moe": ["Qwen2MoeSparseMoeBlock"],
}

View File

@@ -8,6 +8,7 @@ import importlib
import importlib.util import importlib.util
import logging import logging
import math import math
import os
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from collections import defaultdict from collections import defaultdict
@@ -28,7 +29,7 @@ from transformers import (
TrainerCallback, TrainerCallback,
TrainingArguments, TrainingArguments,
) )
from transformers.trainer_utils import seed_worker from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import ( from trl import (
CPOConfig, CPOConfig,
@@ -286,7 +287,77 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
) )
class AxolotlTrainer(Trainer): class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: AxolotlTrainingArguments
def create_scheduler(
    self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
    """
    Create the learning-rate scheduler, or hand back the one that is already set.

    If a scheduler already exists it is returned untouched; a warning is logged
    when one of axolotl's cosine variants was requested but could not be used.
    Otherwise, with `lr_scheduler_type == "cosine"`, the variants are tried in
    this order: quadratic warmup, then warmup/decay/constant (both
    `cosine_min_lr_ratio` and `cosine_constant_lr_ratio` are truthy), then
    min-lr only. Every other configuration is delegated to the parent class.

    Args:
        num_training_steps (int): The number of training steps to do.
        optimizer (torch.optim.Optimizer): The training optimizer

    Returns:
        The scheduler stored on `self.lr_scheduler`, or whatever the parent
        implementation returns when the request is delegated.
    """
    training_args = self.args
    cosine_selected = training_args.lr_scheduler_type == "cosine"
    use_cosine_quadratic = (
        cosine_selected and training_args.lr_quadratic_warmup is True
    )
    use_cosine_min_lr = (
        cosine_selected and training_args.cosine_min_lr_ratio is not None
    )

    if self.lr_scheduler is not None:  # type: ignore # pylint: disable=access-member-before-definition
        # A scheduler was installed before this call; keep it and only report
        # which requested axolotl schedule is being skipped.
        if use_cosine_quadratic:
            LOG.warning(
                "axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed)."
            )
        if use_cosine_min_lr:
            LOG.warning(
                "axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed)."
            )
        return self.lr_scheduler

    if use_cosine_quadratic:
        if use_cosine_min_lr:
            LOG.warning(
                "Both cosine quadratic warmup and min lr detected. Using quadratic warmup."
            )
        self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
            optimizer,
            num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
        )
        return self.lr_scheduler

    min_lr_ratio = training_args.cosine_min_lr_ratio
    if not (min_lr_ratio and use_cosine_min_lr):
        # Neither custom cosine variant applies (this includes a ratio of 0).
        return super().create_scheduler(num_training_steps, optimizer)

    assert 0 <= min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
    constant_lr_ratio = training_args.cosine_constant_lr_ratio
    if constant_lr_ratio:
        assert 0 <= constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
        self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(  # pylint: disable=attribute-defined-outside-init
            optimizer,
            num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
            min_lr_ratio=min_lr_ratio,
            constant_lr_ratio=constant_lr_ratio,
        )
    else:
        self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
            optimizer,
            num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
            min_lr_ratio=min_lr_ratio,
        )
    return self.lr_scheduler
class AxolotlTrainer(SchedulerMixin, Trainer):
""" """
Extend the base Trainer for axolotl helpers Extend the base Trainer for axolotl helpers
""" """
@@ -404,68 +475,6 @@ class AxolotlTrainer(Trainer):
return self.optimizer return self.optimizer
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining: if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches: if self.args.multipack_real_batches:
@@ -830,6 +839,14 @@ class AxolotlTrainer(Trainer):
for key, value in metrics.items(): for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value) self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, metrics=None):
    """
    Save a checkpoint after making sure its target directory exists.

    The folder `<run dir>/<PREFIX_CHECKPOINT_DIR>-<global step>` is created up
    front (no error if it is already there) because the parent trainer does not
    reliably create it; the actual saving is then left to the parent class.
    """
    step_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
    target_dir = os.path.join(self._get_output_dir(trial=trial), step_folder)
    os.makedirs(target_dir, exist_ok=True)
    return super()._save_checkpoint(model, trial, metrics=metrics)
class AxolotlMambaTrainer(AxolotlTrainer): class AxolotlMambaTrainer(AxolotlTrainer):
""" """
@@ -929,7 +946,7 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler return self.lr_scheduler
class AxolotlDPOTrainer(DPOTrainer): class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
""" """
Extend the base DPOTrainer for axolotl helpers Extend the base DPOTrainer for axolotl helpers
""" """
@@ -990,7 +1007,7 @@ class AxolotlDPOTrainer(DPOTrainer):
return res return res
class AxolotlORPOTrainer(ORPOTrainer): class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
""" """
Extend the base ORPOTrainer for axolotl helpers Extend the base ORPOTrainer for axolotl helpers
""" """
@@ -998,7 +1015,7 @@ class AxolotlORPOTrainer(ORPOTrainer):
tag_names = ["axolotl", "orpo"] tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(KTOTrainer): class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
""" """
Extend the base KTOTrainer for axolotl helpers Extend the base KTOTrainer for axolotl helpers
""" """
@@ -1006,7 +1023,7 @@ class AxolotlKTOTrainer(KTOTrainer):
tag_names = ["axolotl", "kto"] tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(CPOTrainer): class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
""" """
Extend the base CPOTrainer for axolotl helpers Extend the base CPOTrainer for axolotl helpers
""" """
@@ -1750,6 +1767,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl == "simpo": if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo" training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None: if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha

View File

@@ -62,7 +62,7 @@ def default(
tokenize=False, tokenize=False,
) )
chosen_strip_index = result["chosen"].find(chosen["content"]) chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:] result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template( result["rejected"] = tokenizer.apply_chat_template(
[rejected], [rejected],
@@ -71,7 +71,7 @@ def default(
tokenize=False, tokenize=False,
) )
rejected_strip_index = result["rejected"].find(rejected["content"]) rejected_strip_index = result["rejected"].find(rejected["content"])
result["rejected"] = result["rejected"][rejected_strip_index:] result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
return result return result

View File

@@ -212,26 +212,23 @@ def train(
elif cfg.deepspeed and is_deepspeed_zero3_enabled(): elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone() trainer.accelerator.wait_for_everyone()
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped) trainer.save_model(cfg.output_dir)
# the trainer saved a model.safetensors file in the output directory, # the trainer saved a model.safetensors file in the output directory,
# but it is a proxy model and should be deleted # but it is most likely a proxy model and if so, should be deleted
if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")): maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors"))
maybe_sharded = os.path.exists(
os.path.join(cfg.output_dir, "model.safetensors.index.json")
)
if maybe_proxy and maybe_sharded:
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}") LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
LOG.info("This is a proxy model and should be deleted") LOG.info("This is a proxy model and should be deleted")
os.remove(os.path.join(cfg.output_dir, "model.safetensors")) try:
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer: if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)

View File

@@ -42,7 +42,7 @@ from axolotl.prompters import (
from axolotl.utils.data.pretraining import wrap_pretraining_dataset from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import md5 from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
calculate_total_num_steps, calculate_total_num_steps,
process_datasets_for_packing, process_datasets_for_packing,
@@ -54,7 +54,7 @@ LOG = logging.getLogger("axolotl")
def prepare_dataset(cfg, tokenizer): def prepare_dataset(cfg, tokenizer):
prompters = [] prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
with zero_first(is_main_process()): with zero_first(is_local_main_process()):
if cfg.test_datasets: if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets( train_dataset, _, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train" tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
@@ -170,6 +170,7 @@ def load_tokenized_prepared_datasets(
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
if dataset: if dataset:
# This is for the case where we already loaded a pretokenized dataset from the hub
... ...
elif ( elif (
cfg.dataset_prepared_path cfg.dataset_prepared_path
@@ -198,6 +199,8 @@ def load_tokenized_prepared_datasets(
def for_d_in_datasets(dataset_configs): def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs: for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list): if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name: for name in dataset.name:
yield DictDefault({**dataset, "name": name}) yield DictDefault({**dataset, "name": name})
else: else:
@@ -208,6 +211,8 @@ def load_tokenized_prepared_datasets(
ds: Optional[Union[Dataset, DatasetDict]] = None ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False ds_from_hub = False
try: try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset( load_dataset(
config_dataset.path, config_dataset.path,
name=config_dataset.name, name=config_dataset.name,

View File

@@ -44,6 +44,10 @@ def is_main_process():
return dist.get_rank() == 0 return dist.get_rank() == 0
def is_local_main_process():
    """
    Return True when this process is the main process of its own node.

    This deliberately reads `is_local_main_process` rather than
    `is_main_process`: the latter is only true for the single global main
    process, which would make this helper indistinguishable from a global
    rank-0 check on multi-node runs.
    """
    return PartialState().is_local_main_process
def get_world_size(): def get_world_size():
return int(os.getenv("WORLD_SIZE", "1")) return int(os.getenv("WORLD_SIZE", "1"))

View File

@@ -29,6 +29,7 @@ from transformers import ( # noqa: F401
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
AwqConfig,
BitsAndBytesConfig, BitsAndBytesConfig,
GPTQConfig, GPTQConfig,
PreTrainedModel, PreTrainedModel,
@@ -36,6 +37,7 @@ from transformers import ( # noqa: F401
) )
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import ( from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES, SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -510,7 +512,25 @@ def load_model(
model_kwargs["quantization_config"] = GPTQConfig( model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config **model_config.quantization_config
) )
if cfg.adapter == "qlora" and cfg.load_in_4bit: if (
cfg.adapter in ["qlora", "lora"]
and hasattr(model_config, "quantization_config")
and model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
):
if model_config.quantization_config["quant_method"] == "gptq":
model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config
)
elif model_config.quantization_config["quant_method"] == "awq":
model_kwargs["quantization_config"] = AwqConfig(
**model_config.quantization_config
)
elif model_config.quantization_config["quant_method"] == "bitsandbytes":
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**model_config.quantization_config
)
elif cfg.adapter == "qlora" and cfg.load_in_4bit:
bnb_config = { bnb_config = {
"load_in_4bit": True, "load_in_4bit": True,
"llm_int8_threshold": 6.0, "llm_int8_threshold": 6.0,
@@ -785,12 +805,14 @@ def load_model(
set_z3_leaf_modules, set_z3_leaf_modules,
) )
if cfg.model_config_type == "mixtral": if cfg.model_config_type in MOE_ARCH_BLOCK:
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock") set_z3_leaf_modules(
set_z3_leaf_modules(model, [moe_block]) model,
elif cfg.model_config_type == "dbrx": [
moe_block = get_module_class_from_name(model, "DbrxFFN") get_module_class_from_name(model, module_name)
set_z3_leaf_modules(model, [moe_block]) for module_name in MOE_ARCH_BLOCK[cfg.model_config_type]
],
)
if cfg.model_config_type == "qwen" and cfg.adapter == "lora": if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled # Qwen doesn't play nicely with LoRA if this is enabled
@@ -804,6 +826,9 @@ def load_model(
# make sure everything is in the same dtype # make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True skip_prepare_model_for_kbit_training = True
if is_deepspeed_zero3_enabled():
skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]: if cfg.adapter in ["lora", "qlora"]:
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable( model.gradient_checkpointing_enable(
@@ -838,6 +863,9 @@ def load_model(
else: else:
model, lora_config = load_adapter(model, cfg, cfg.adapter) model, lora_config = load_adapter(model, cfg, cfg.adapter)
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
if ( if (
cfg.ddp cfg.ddp
and not load_in_8bit and not load_in_8bit

View File

@@ -1,4 +1,5 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import json
import math import math
import os import os
import random import random
@@ -389,6 +390,19 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
return total_num_steps return total_num_steps
def setup_deepspeed_env(cfg, stage=None):
    """
    Export the ACCELERATE_* environment variables that enable DeepSpeed.

    Args:
        cfg: config object; `cfg.deepspeed` is exported as the DeepSpeed config
            file, and `cfg.bf16` / `cfg.fp16` select the mixed-precision mode
            (bf16 wins when both are truthy; nothing is exported when neither is).
        stage: optional ZeRO stage. When truthy it is exported as a string, and
            stage 3 additionally turns on ZeRO-3 init.
    """
    env_updates = {
        "ACCELERATE_USE_DEEPSPEED": "true",
        "ACCELERATE_DEEPSPEED_CONFIG_FILE": cfg.deepspeed,
    }

    if cfg.bf16:
        env_updates["ACCELERATE_MIXED_PRECISION"] = "bf16"
    elif cfg.fp16:
        env_updates["ACCELERATE_MIXED_PRECISION"] = "fp16"

    if stage:
        env_updates["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
    if stage == 3:
        env_updates["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"

    os.environ.update(env_updates)
def setup_fsdp_envs(cfg): def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true" os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_activation_checkpointing: if cfg.fsdp_config.fsdp_activation_checkpointing:
@@ -415,8 +429,14 @@ def prepare_optim_env(cfg):
if cfg.fsdp: if cfg.fsdp:
setup_fsdp_envs(cfg) setup_fsdp_envs(cfg)
elif cfg.deepspeed: elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" stage = None
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed # check if the cfg.deepspeed is a file
if os.path.isfile(cfg.deepspeed):
# parse with json
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
deepspeed_config = json.load(fin)
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage)
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True: if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16" os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"