From 94ba93259f421d438e53117267b17097c48cdd65 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 28 Jul 2024 07:25:54 -0400 Subject: [PATCH] various batch of fixes (#1785) * various batch of fixes * more tweaks * fix autoawq requirement for torch flexibility * simplify conditionals * multi-node fixes wip * bump transformers and include 405b qlora+fsdp yaml --- examples/llama-3/qlora-fsdp-405b.yaml | 62 +++++++ requirements.txt | 3 +- src/axolotl/cli/preprocess.py | 9 +- src/axolotl/common/architectures.py | 14 ++ src/axolotl/core/trainer_builder.py | 154 ++++++++++-------- .../prompt_strategies/dpo/chat_template.py | 4 +- src/axolotl/train.py | 27 ++- src/axolotl/utils/data/sft.py | 9 +- src/axolotl/utils/distributed.py | 4 + src/axolotl/utils/models.py | 42 ++++- src/axolotl/utils/trainer.py | 24 ++- 11 files changed, 253 insertions(+), 99 deletions(-) create mode 100644 examples/llama-3/qlora-fsdp-405b.yaml create mode 100644 src/axolotl/common/architectures.py diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml new file mode 100644 index 000000000..385b7f91d --- /dev/null +++ b/examples/llama-3/qlora-fsdp-405b.yaml @@ -0,0 +1,62 @@ +base_model: meta-llama/Meta-Llama-3.1-405B +tokenizer_type: AutoTokenizer + +load_in_4bit: true +strict: false + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out/qlora-llama3_1-405b + +adapter: qlora + +sequence_len: 1024 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 16 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: +lora_target_linear: true + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 0.00001 + +train_on_inputs: false +group_by_length: false +bf16: true +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true +logging_steps: 1 +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: 4 +saves_per_epoch: 1 +weight_decay: 0.0 +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: true + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD +special_tokens: + pad_token: <|finetune_right_pad_id|> diff --git a/requirements.txt b/requirements.txt index 981a62558..5825ee190 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.11.1 -transformers==4.43.1 +transformers==4.43.3 tokenizers==0.19.1 bitsandbytes==0.43.1 accelerate==0.32.0 @@ -32,6 +32,7 @@ fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e59 gradio==3.50.2 tensorboard python-dotenv==1.0.1 +autoawq>=0.2.5 mamba-ssm==1.2.0.post1 diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 5ec279d4b..e0dd7c2dc 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -2,6 +2,7 @@ CLI to run training on a model """ import logging +import warnings from pathlib import Path from typing import Union @@ -76,8 +77,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): if parsed_cli_args.download: model_name = parsed_cfg.base_model - with init_empty_weights(): - 
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + with warnings.catch_warnings(): + # there are a bunch of useless UserWarnings about + # "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model" + warnings.simplefilter("ignore") + with init_empty_weights(include_buffers=True): + AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) LOG.info( Fore.GREEN diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py new file mode 100644 index 000000000..7610b335a --- /dev/null +++ b/src/axolotl/common/architectures.py @@ -0,0 +1,14 @@ +""" +Common architecture specific constants +""" + +MOE_ARCH_BLOCK = { + "dbrx": "DbrxFFN", + "jamba": "JambaSparseMoeBlock", + "jetmoe": [ + "JetMoeMoA", + "JetMoeMoE", + ], + "mixtral": "MixtralSparseMoeBlock", + "qwen2_moe": "Qwen2MoeSparseMoeBlock", +} diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 9a12c5a06..ff4804b10 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -8,6 +8,7 @@ import importlib import importlib.util import logging import math +import os import sys from abc import abstractmethod from collections import defaultdict @@ -28,7 +29,7 @@ from transformers import ( TrainerCallback, TrainingArguments, ) -from transformers.trainer_utils import seed_worker +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker from transformers.utils import is_sagemaker_mp_enabled from trl import ( CPOConfig, @@ -286,7 +287,77 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig): ) -class AxolotlTrainer(Trainer): +class SchedulerMixin(Trainer): + """ + Mixin class for scheduler setup in CausalTrainer. + """ + + args = None # type: AxolotlTrainingArguments + + def create_scheduler( + self, num_training_steps: int, optimizer: torch.optim.Optimizer = None + ): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + optimizer (torch.optim.Optimizer): The training optimizer + """ + use_cosine_quadratic = ( + self.args.lr_scheduler_type == "cosine" + and self.args.lr_quadratic_warmup is True + ) + + use_cosine_min_lr = ( + self.args.lr_scheduler_type == "cosine" + and self.args.cosine_min_lr_ratio is not None + ) + + # fmt: off + if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition + # fmt: on + if use_cosine_quadratic: + if use_cosine_min_lr: + LOG.warning("Both cosine quadratic warmup and min lr detected. 
Using quadratic warmup.") + + self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + constant_lr_ratio=self.args.cosine_constant_lr_ratio, + ) + elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + ) + else: + return super().create_scheduler(num_training_steps, optimizer) + else: + if use_cosine_quadratic: + LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") + + if use_cosine_min_lr: + LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") + + return self.lr_scheduler + + +class AxolotlTrainer(SchedulerMixin, Trainer): """ Extend the base Trainer for axolotl helpers """ @@ -404,68 +475,6 @@ class AxolotlTrainer(Trainer): return self.optimizer - def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None - ): - """ - Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or - passed as an argument. - - Args: - num_training_steps (int): The number of training steps to do. - optimizer (torch.optim.Optimizer): The training optimizer - """ - use_cosine_quadratic = ( - self.args.lr_scheduler_type == "cosine" - and self.args.lr_quadratic_warmup is True - ) - - use_cosine_min_lr = ( - self.args.lr_scheduler_type == "cosine" - and self.args.cosine_min_lr_ratio is not None - ) - - # fmt: off - if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition - # fmt: on - if use_cosine_quadratic: - if use_cosine_min_lr: - LOG.warning("Both cosine quadratic warmup and min lr detected. 
Using quadratic warmup.") - - self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - ) - elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - constant_lr_ratio=self.args.cosine_constant_lr_ratio, - ) - elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - ) - else: - return super().create_scheduler(num_training_steps, optimizer) - else: - if use_cosine_quadratic: - LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") - - if use_cosine_min_lr: - LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") - - return self.lr_scheduler - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if self.args.sample_packing and not self.args.pretraining: if self.args.multipack_real_batches: @@ -830,6 +839,14 @@ class AxolotlTrainer(Trainer): for key, value in metrics.items(): self._stored_metrics[train_eval][key].append(value) + def _save_checkpoint(self, model, trial, metrics=None): + # make sure the checkpoint dir exists, since trainer is flakey + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + os.makedirs(output_dir, exist_ok=True) + return super()._save_checkpoint(model, trial, metrics=metrics) + class AxolotlMambaTrainer(AxolotlTrainer): """ @@ -929,7 +946,7 @@ class ReLoRATrainer(AxolotlTrainer): return self.lr_scheduler -class AxolotlDPOTrainer(DPOTrainer): +class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): """ Extend the base DPOTrainer for axolotl helpers """ @@ -990,7 +1007,7 @@ class AxolotlDPOTrainer(DPOTrainer): return res -class AxolotlORPOTrainer(ORPOTrainer): +class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): """ Extend the base ORPOTrainer for axolotl helpers """ @@ -998,7 +1015,7 @@ class AxolotlORPOTrainer(ORPOTrainer): tag_names = ["axolotl", "orpo"] -class AxolotlKTOTrainer(KTOTrainer): +class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): """ Extend the base KTOTrainer for axolotl helpers """ @@ -1006,7 +1023,7 @@ class AxolotlKTOTrainer(KTOTrainer): tag_names = ["axolotl", "kto"] -class AxolotlCPOTrainer(CPOTrainer): +class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): """ Extend the base CPOTrainer for axolotl helpers """ @@ -1750,6 +1767,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl == "simpo": training_args_cls = AxolotlCPOConfig 
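            # note: SimPO is implemented on top of TRL's CPOTrainer, so it reuses
            # the CPO config with loss_type="simpo" and the kwargs set below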
training_args_kwargs["loss_type"] = "simpo" + training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma if self.cfg.cpo_alpha is not None: training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index 4f2f14098..e0e5eb129 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -62,7 +62,7 @@ def default( tokenize=False, ) chosen_strip_index = result["chosen"].find(chosen["content"]) - result["chosen"] = result["chosen"][chosen_strip_index:] + result["chosen"] = result["chosen"][chosen_strip_index:].rstrip() result["rejected"] = tokenizer.apply_chat_template( [rejected], @@ -71,7 +71,7 @@ def default( tokenize=False, ) rejected_strip_index = result["rejected"].find(rejected["content"]) - result["rejected"] = result["rejected"][rejected_strip_index:] + result["rejected"] = result["rejected"][rejected_strip_index:].rstrip() return result diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 5ba5aed56..b8890d4f7 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -212,26 +212,23 @@ def train( elif cfg.deepspeed and is_deepspeed_zero3_enabled(): # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading trainer.accelerator.wait_for_everyone() - unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped) + trainer.save_model(cfg.output_dir) # the trainer saved a model.safetensors file in the output directory, - # but it is a proxy model and should be deleted - if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")): + # but it is most likely a proxy model and if so, should be deleted + maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")) + maybe_sharded = os.path.exists( + os.path.join(cfg.output_dir, "model.safetensors.index.json") + ) + + if maybe_proxy and maybe_sharded: LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}") LOG.info("This is a proxy model and should be deleted") - os.remove(os.path.join(cfg.output_dir, "model.safetensors")) + try: + os.remove(os.path.join(cfg.output_dir, "model.safetensors")) + except FileNotFoundError: + pass - # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if - # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or - # `zero3_save_16bit_model` is True in DeepSpeed Plugin. - # For Zero Stages 1 and 2, models are saved as usual in the output directory. 
-        # The model name saved is `pytorch_model.bin`
-        unwrapped_model.save_pretrained(
-            cfg.output_dir,
-            is_main_process=trainer.accelerator.is_main_process,
-            save_function=trainer.accelerator.save,
-            state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
-        )
     elif cfg.local_rank == 0:
         if cfg.flash_optimum and BetterTransformer:
             model = BetterTransformer.reverse(model)
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index bbea1987f..2e923057d 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -42,7 +42,7 @@ from axolotl.prompters import (
 from axolotl.utils.data.pretraining import wrap_pretraining_dataset
 from axolotl.utils.data.utils import md5
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.distributed import is_local_main_process, zero_first
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
     process_datasets_for_packing,
@@ -54,7 +54,7 @@ LOG = logging.getLogger("axolotl")
 def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
-        with zero_first(is_main_process()):
+        with zero_first(is_local_main_process()):
             if cfg.test_datasets:
                 train_dataset, _, prompters = load_prepare_datasets(
                     tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
@@ -170,6 +170,7 @@ def load_tokenized_prepared_datasets(
         # pylint: disable=duplicate-code
         if dataset:
+            # This is for the case where we already loaded a pretokenized dataset from the hub
             ...
         elif (
             cfg.dataset_prepared_path
@@ -198,6 +199,8 @@ def load_tokenized_prepared_datasets(
     def for_d_in_datasets(dataset_configs):
         for dataset in dataset_configs:
             if dataset.name and isinstance(dataset.name, list):
+                # load_dataset doesn't properly handle multiple named configurations
+                # at the same time for a given dataset
                 for name in dataset.name:
                     yield DictDefault({**dataset, "name": name})
             else:
@@ -208,6 +211,8 @@ def load_tokenized_prepared_datasets(
     ds: Optional[Union[Dataset, DatasetDict]] = None
     ds_from_hub = False
     try:
+        # this is just a basic check to see if the path is a
+        # valid HF dataset that's loadable
         load_dataset(
             config_dataset.path,
             name=config_dataset.name,
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index ecb1bcc9e..4444a20c9 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -44,6 +44,10 @@ def is_main_process():
     return dist.get_rank() == 0


+def is_local_main_process():
+    return PartialState().is_local_main_process
+
+
 def get_world_size():
     return int(os.getenv("WORLD_SIZE", "1"))
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 436b31fef..8a50631ef 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -29,6 +29,7 @@ from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
+    AwqConfig,
     BitsAndBytesConfig,
     GPTQConfig,
     PreTrainedModel,
@@ -36,6 +37,7 @@ from transformers import (  # noqa: F401
 )
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

+from axolotl.common.architectures import MOE_ARCH_BLOCK
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -510,7 +512,25 @@ def load_model(
             model_kwargs["quantization_config"] = GPTQConfig(
                 **model_config.quantization_config
             )
-    if cfg.adapter == "qlora" and cfg.load_in_4bit:
+    if (
+        cfg.adapter in ["qlora", "lora"]
+        and hasattr(model_config, "quantization_config")
"quantization_config") + and model_config.quantization_config["quant_method"] + in ["gptq", "awq", "bitsandbytes"] + ): + if model_config.quantization_config["quant_method"] == "gptq": + model_kwargs["quantization_config"] = GPTQConfig( + **model_config.quantization_config + ) + elif model_config.quantization_config["quant_method"] == "awq": + model_kwargs["quantization_config"] = AwqConfig( + **model_config.quantization_config + ) + elif model_config.quantization_config["quant_method"] == "bitsandbytes": + model_kwargs["quantization_config"] = BitsAndBytesConfig( + **model_config.quantization_config + ) + elif cfg.adapter == "qlora" and cfg.load_in_4bit: bnb_config = { "load_in_4bit": True, "llm_int8_threshold": 6.0, @@ -785,12 +805,14 @@ def load_model( set_z3_leaf_modules, ) - if cfg.model_config_type == "mixtral": - moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock") - set_z3_leaf_modules(model, [moe_block]) - elif cfg.model_config_type == "dbrx": - moe_block = get_module_class_from_name(model, "DbrxFFN") - set_z3_leaf_modules(model, [moe_block]) + if cfg.model_config_type in MOE_ARCH_BLOCK: + set_z3_leaf_modules( + model, + [ + get_module_class_from_name(model, module_name) + for module_name in MOE_ARCH_BLOCK[cfg.model_config_type] + ], + ) if cfg.model_config_type == "qwen" and cfg.adapter == "lora": # Qwen doesn't play nicely with LoRA if this is enabled @@ -804,6 +826,9 @@ def load_model( # make sure everything is in the same dtype skip_prepare_model_for_kbit_training = True + if is_deepspeed_zero3_enabled(): + skip_prepare_model_for_kbit_training = True + if cfg.adapter in ["lora", "qlora"]: if cfg.gradient_checkpointing: model.gradient_checkpointing_enable( @@ -838,6 +863,9 @@ def load_model( else: model, lora_config = load_adapter(model, cfg, cfg.adapter) + if is_deepspeed_zero3_enabled(): + skip_move_to_device = True + if ( cfg.ddp and not load_in_8bit diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index c5a71e689..bb9624051 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -1,4 +1,5 @@ """Module containing the Trainer class and related functions""" +import json import math import os import random @@ -389,6 +390,19 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): return total_num_steps +def setup_deepspeed_env(cfg, stage=None): + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed + if cfg.bf16: + os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16" + elif cfg.fp16: + os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" + if stage: + os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage) + if stage == 3: + os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true" + + def setup_fsdp_envs(cfg): os.environ["ACCELERATE_USE_FSDP"] = "true" if cfg.fsdp_config.fsdp_activation_checkpointing: @@ -415,8 +429,14 @@ def prepare_optim_env(cfg): if cfg.fsdp: setup_fsdp_envs(cfg) elif cfg.deepspeed: - os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" - os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed + stage = None + # check if the cfg.deepspeed is a file + if os.path.isfile(cfg.deepspeed): + # parse with json + with open(cfg.deepspeed, "r", encoding="utf-8") as fin: + deepspeed_config = json.load(fin) + stage = deepspeed_config.get("zero_optimization", {}).get("stage", None) + setup_deepspeed_env(cfg, stage=stage) if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True: os.environ["ACCELERATE_MIXED_PRECISION"] = 
"bf16"