various batch of fixes (#1785)
* various batch of fixes * more tweaks * fix autoawq requirement for torch flexibility * simplify conditionals * multi-node fixes wip * bump transformers and include 405b qlora+fsdp yaml
This commit is contained in:
62
examples/llama-3/qlora-fsdp-405b.yaml
Normal file
62
examples/llama-3/qlora-fsdp-405b.yaml
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
base_model: meta-llama/Meta-Llama-3.1-405B
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out/qlora-llama3_1-405b
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
|
||||||
|
sequence_len: 1024
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.00001
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: true
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: true
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|finetune_right_pad_id|>
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.11.1
|
peft==0.11.1
|
||||||
transformers==4.43.1
|
transformers==4.43.3
|
||||||
tokenizers==0.19.1
|
tokenizers==0.19.1
|
||||||
bitsandbytes==0.43.1
|
bitsandbytes==0.43.1
|
||||||
accelerate==0.32.0
|
accelerate==0.32.0
|
||||||
@@ -32,6 +32,7 @@ fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e59
|
|||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
|
autoawq>=0.2.5
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
CLI to run training on a model
|
CLI to run training on a model
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -76,8 +77,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
if parsed_cli_args.download:
|
if parsed_cli_args.download:
|
||||||
model_name = parsed_cfg.base_model
|
model_name = parsed_cfg.base_model
|
||||||
with init_empty_weights():
|
with warnings.catch_warnings():
|
||||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
# there are a bunch of useless UserWarnings about
|
||||||
|
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
with init_empty_weights(include_buffers=True):
|
||||||
|
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
Fore.GREEN
|
Fore.GREEN
|
||||||
|
|||||||
14
src/axolotl/common/architectures.py
Normal file
14
src/axolotl/common/architectures.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Common architecture specific constants
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Maps a model config type to the class name(s) of its mixture-of-experts block(s).
#
# Every value is a list of class names, even when an architecture has only one
# MoE block. Model loading iterates over the value
# (`for module_name in MOE_ARCH_BLOCK[cfg.model_config_type]`); a bare string
# there would be walked character by character instead of being treated as a
# single class name.
MOE_ARCH_BLOCK = {
    "dbrx": ["DbrxFFN"],
    "jamba": ["JambaSparseMoeBlock"],
    "jetmoe": [
        "JetMoeMoA",
        "JetMoeMoE",
    ],
    "mixtral": ["MixtralSparseMoeBlock"],
    "qwen2_moe": ["Qwen2MoeSparseMoeBlock"],
}
|
||||||
@@ -8,6 +8,7 @@ import importlib
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@@ -28,7 +29,7 @@ from transformers import (
|
|||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import (
|
from trl import (
|
||||||
CPOConfig,
|
CPOConfig,
|
||||||
@@ -286,7 +287,77 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
class SchedulerMixin(Trainer):
    """
    Mixin class for scheduler setup in CausalTrainer.

    Overrides `create_scheduler` so that, when the scheduler type is "cosine",
    one of axolotl's custom cosine schedules is selected from the training
    arguments. In every other case the parent implementation is used.
    """

    args = None  # type: AxolotlTrainingArguments

    def create_scheduler(
        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
    ):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Selection order when no scheduler exists yet:
          1. cosine + `lr_quadratic_warmup` -> cosine schedule with quadratic warmup
          2. cosine + `cosine_min_lr_ratio` + `cosine_constant_lr_ratio`
             -> cosine schedule with warmup, decay and a constant tail
          3. cosine + `cosine_min_lr_ratio` -> cosine schedule with a minimum lr
          4. anything else -> the parent class's `create_scheduler`

        If a scheduler is already set, it is left untouched and a warning is
        logged for each custom option that was requested but not applied.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (torch.optim.Optimizer): The training optimizer. When
                omitted, the trainer's own `self.optimizer` is used.

        Returns:
            The learning rate scheduler in use (or whatever the parent
            implementation returns when it is delegated to).
        """
        use_cosine_quadratic = (
            self.args.lr_scheduler_type == "cosine"
            and self.args.lr_quadratic_warmup is True
        )

        use_cosine_min_lr = (
            self.args.lr_scheduler_type == "cosine"
            and self.args.cosine_min_lr_ratio is not None
        )

        # fmt: off
        if self.lr_scheduler is None:  # type: ignore  # pylint: disable=access-member-before-definition
            # fmt: on
            # Honour the documented contract: fall back to the trainer's own
            # optimizer so the custom schedulers never receive None.
            if optimizer is None:
                optimizer = self.optimizer

            if use_cosine_quadratic:
                if use_cosine_min_lr:
                    LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                )
            elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
                assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
                self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                    constant_lr_ratio=self.args.cosine_constant_lr_ratio,
                )
            elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
                self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                )
            else:
                # No custom cosine variant applies (this includes a
                # cosine_min_lr_ratio of 0): defer to the parent implementation.
                return super().create_scheduler(num_training_steps, optimizer)
        else:
            # A scheduler was already set before this call; keep it and only
            # report which custom options were not applied.
            if use_cosine_quadratic:
                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")

            if use_cosine_min_lr:
                LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

        return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||||
"""
|
"""
|
||||||
Extend the base Trainer for axolotl helpers
|
Extend the base Trainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -404,68 +475,6 @@ class AxolotlTrainer(Trainer):
|
|||||||
|
|
||||||
return self.optimizer
|
return self.optimizer
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
|
||||||
passed as an argument.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_training_steps (int): The number of training steps to do.
|
|
||||||
optimizer (torch.optim.Optimizer): The training optimizer
|
|
||||||
"""
|
|
||||||
use_cosine_quadratic = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.lr_quadratic_warmup is True
|
|
||||||
)
|
|
||||||
|
|
||||||
use_cosine_min_lr = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.cosine_min_lr_ratio is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
|
||||||
# fmt: on
|
|
||||||
if use_cosine_quadratic:
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
|
||||||
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return super().create_scheduler(num_training_steps, optimizer)
|
|
||||||
else:
|
|
||||||
if use_cosine_quadratic:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
if self.args.multipack_real_batches:
|
if self.args.multipack_real_batches:
|
||||||
@@ -830,6 +839,14 @@ class AxolotlTrainer(Trainer):
|
|||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
self._stored_metrics[train_eval][key].append(value)
|
||||||
|
|
||||||
|
def _save_checkpoint(self, model, trial, metrics=None):
    """
    Create the checkpoint directory for the current step, then save as usual.

    The directory `<run dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>` is created
    up front (no error if it already exists) before delegating to the parent
    class's `_save_checkpoint`, whose result is returned unchanged.
    """
    # make sure the checkpoint dir exists, since trainer is flakey
    step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
    target_dir = os.path.join(self._get_output_dir(trial=trial), step_dir_name)
    os.makedirs(target_dir, exist_ok=True)
    return super()._save_checkpoint(model, trial, metrics=metrics)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -929,7 +946,7 @@ class ReLoRATrainer(AxolotlTrainer):
|
|||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(DPOTrainer):
|
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base DPOTrainer for axolotl helpers
|
Extend the base DPOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -990,7 +1007,7 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -998,7 +1015,7 @@ class AxolotlORPOTrainer(ORPOTrainer):
|
|||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(KTOTrainer):
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -1006,7 +1023,7 @@ class AxolotlKTOTrainer(KTOTrainer):
|
|||||||
tag_names = ["axolotl", "kto"]
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(CPOTrainer):
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -1750,6 +1767,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rl == "simpo":
|
if self.cfg.rl == "simpo":
|
||||||
training_args_cls = AxolotlCPOConfig
|
training_args_cls = AxolotlCPOConfig
|
||||||
training_args_kwargs["loss_type"] = "simpo"
|
training_args_kwargs["loss_type"] = "simpo"
|
||||||
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
|
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
|
||||||
if self.cfg.cpo_alpha is not None:
|
if self.cfg.cpo_alpha is not None:
|
||||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ def default(
|
|||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
chosen_strip_index = result["chosen"].find(chosen["content"])
|
chosen_strip_index = result["chosen"].find(chosen["content"])
|
||||||
result["chosen"] = result["chosen"][chosen_strip_index:]
|
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
|
||||||
|
|
||||||
result["rejected"] = tokenizer.apply_chat_template(
|
result["rejected"] = tokenizer.apply_chat_template(
|
||||||
[rejected],
|
[rejected],
|
||||||
@@ -71,7 +71,7 @@ def default(
|
|||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
rejected_strip_index = result["rejected"].find(rejected["content"])
|
rejected_strip_index = result["rejected"].find(rejected["content"])
|
||||||
result["rejected"] = result["rejected"][rejected_strip_index:]
|
result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -212,26 +212,23 @@ def train(
|
|||||||
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
||||||
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
||||||
trainer.accelerator.wait_for_everyone()
|
trainer.accelerator.wait_for_everyone()
|
||||||
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
|
trainer.save_model(cfg.output_dir)
|
||||||
|
|
||||||
# the trainer saved a model.safetensors file in the output directory,
|
# the trainer saved a model.safetensors file in the output directory,
|
||||||
# but it is a proxy model and should be deleted
|
# but it is most likely a proxy model and if so, should be deleted
|
||||||
if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
|
maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||||
|
maybe_sharded = os.path.exists(
|
||||||
|
os.path.join(cfg.output_dir, "model.safetensors.index.json")
|
||||||
|
)
|
||||||
|
|
||||||
|
if maybe_proxy and maybe_sharded:
|
||||||
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
|
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
|
||||||
LOG.info("This is a proxy model and should be deleted")
|
LOG.info("This is a proxy model and should be deleted")
|
||||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
try:
|
||||||
|
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
|
|
||||||
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
|
|
||||||
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|
|
||||||
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
|
|
||||||
# The model name saved is `pytorch_model.bin`
|
|
||||||
unwrapped_model.save_pretrained(
|
|
||||||
cfg.output_dir,
|
|
||||||
is_main_process=trainer.accelerator.is_main_process,
|
|
||||||
save_function=trainer.accelerator.save,
|
|
||||||
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
|
||||||
)
|
|
||||||
elif cfg.local_rank == 0:
|
elif cfg.local_rank == 0:
|
||||||
if cfg.flash_optimum and BetterTransformer:
|
if cfg.flash_optimum and BetterTransformer:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ from axolotl.prompters import (
|
|||||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||||
from axolotl.utils.data.utils import md5
|
from axolotl.utils.data.utils import md5
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_local_main_process, zero_first
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
calculate_total_num_steps,
|
calculate_total_num_steps,
|
||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
@@ -54,7 +54,7 @@ LOG = logging.getLogger("axolotl")
|
|||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer):
|
||||||
prompters = []
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_local_main_process()):
|
||||||
if cfg.test_datasets:
|
if cfg.test_datasets:
|
||||||
train_dataset, _, prompters = load_prepare_datasets(
|
train_dataset, _, prompters = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
|
||||||
@@ -170,6 +170,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
if dataset:
|
if dataset:
|
||||||
|
# This is for the case where we already loaded a pretokenized dataset from the hub
|
||||||
...
|
...
|
||||||
elif (
|
elif (
|
||||||
cfg.dataset_prepared_path
|
cfg.dataset_prepared_path
|
||||||
@@ -198,6 +199,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
def for_d_in_datasets(dataset_configs):
|
def for_d_in_datasets(dataset_configs):
|
||||||
for dataset in dataset_configs:
|
for dataset in dataset_configs:
|
||||||
if dataset.name and isinstance(dataset.name, list):
|
if dataset.name and isinstance(dataset.name, list):
|
||||||
|
# load_dataset doesn't properly handle multiple named configurations
|
||||||
|
# at the same time for a given dataset
|
||||||
for name in dataset.name:
|
for name in dataset.name:
|
||||||
yield DictDefault({**dataset, "name": name})
|
yield DictDefault({**dataset, "name": name})
|
||||||
else:
|
else:
|
||||||
@@ -208,6 +211,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
ds: Optional[Union[Dataset, DatasetDict]] = None
|
ds: Optional[Union[Dataset, DatasetDict]] = None
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
try:
|
try:
|
||||||
|
# this is just a basic check to see if the path is a
|
||||||
|
# valid HF dataset that's loadable
|
||||||
load_dataset(
|
load_dataset(
|
||||||
config_dataset.path,
|
config_dataset.path,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
|
|||||||
@@ -44,6 +44,10 @@ def is_main_process():
|
|||||||
return dist.get_rank() == 0
|
return dist.get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def is_local_main_process():
    """
    Return whether this process is the main process on its local node.

    Unlike `is_main_process()`, which compares the global rank to 0, this
    queries the per-node flag on accelerate's `PartialState`. In a multi-node
    run that is expected to be true for one process on every node, so it can
    gate per-node work such as preparing datasets.
    """
    # Use the node-local flag; `PartialState().is_main_process` is the global
    # check and would contradict this function's name.
    return PartialState().is_local_main_process
|
||||||
|
|
||||||
|
|
||||||
def get_world_size():
|
def get_world_size():
|
||||||
return int(os.getenv("WORLD_SIZE", "1"))
|
return int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from transformers import ( # noqa: F401
|
|||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
AwqConfig,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
@@ -36,6 +37,7 @@ from transformers import ( # noqa: F401
|
|||||||
)
|
)
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
|
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
@@ -510,7 +512,25 @@ def load_model(
|
|||||||
model_kwargs["quantization_config"] = GPTQConfig(
|
model_kwargs["quantization_config"] = GPTQConfig(
|
||||||
**model_config.quantization_config
|
**model_config.quantization_config
|
||||||
)
|
)
|
||||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
if (
|
||||||
|
cfg.adapter in ["qlora", "lora"]
|
||||||
|
and hasattr(model_config, "quantization_config")
|
||||||
|
and model_config.quantization_config["quant_method"]
|
||||||
|
in ["gptq", "awq", "bitsandbytes"]
|
||||||
|
):
|
||||||
|
if model_config.quantization_config["quant_method"] == "gptq":
|
||||||
|
model_kwargs["quantization_config"] = GPTQConfig(
|
||||||
|
**model_config.quantization_config
|
||||||
|
)
|
||||||
|
elif model_config.quantization_config["quant_method"] == "awq":
|
||||||
|
model_kwargs["quantization_config"] = AwqConfig(
|
||||||
|
**model_config.quantization_config
|
||||||
|
)
|
||||||
|
elif model_config.quantization_config["quant_method"] == "bitsandbytes":
|
||||||
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
|
**model_config.quantization_config
|
||||||
|
)
|
||||||
|
elif cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"llm_int8_threshold": 6.0,
|
"llm_int8_threshold": 6.0,
|
||||||
@@ -785,12 +805,14 @@ def load_model(
|
|||||||
set_z3_leaf_modules,
|
set_z3_leaf_modules,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.model_config_type == "mixtral":
|
if cfg.model_config_type in MOE_ARCH_BLOCK:
|
||||||
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
|
set_z3_leaf_modules(
|
||||||
set_z3_leaf_modules(model, [moe_block])
|
model,
|
||||||
elif cfg.model_config_type == "dbrx":
|
[
|
||||||
moe_block = get_module_class_from_name(model, "DbrxFFN")
|
get_module_class_from_name(model, module_name)
|
||||||
set_z3_leaf_modules(model, [moe_block])
|
for module_name in MOE_ARCH_BLOCK[cfg.model_config_type]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
||||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||||
@@ -804,6 +826,9 @@ def load_model(
|
|||||||
# make sure everything is in the same dtype
|
# make sure everything is in the same dtype
|
||||||
skip_prepare_model_for_kbit_training = True
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if cfg.adapter in ["lora", "qlora"]:
|
if cfg.adapter in ["lora", "qlora"]:
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable(
|
model.gradient_checkpointing_enable(
|
||||||
@@ -838,6 +863,9 @@ def load_model(
|
|||||||
else:
|
else:
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
skip_move_to_device = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.ddp
|
cfg.ddp
|
||||||
and not load_in_8bit
|
and not load_in_8bit
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module containing the Trainer class and related functions"""
|
"""Module containing the Trainer class and related functions"""
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -389,6 +390,19 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
return total_num_steps
|
return total_num_steps
|
||||||
|
|
||||||
|
|
||||||
|
def setup_deepspeed_env(cfg, stage=None):
    """
    Export the ACCELERATE_* environment variables for a DeepSpeed run.

    Args:
        cfg: config object; reads `cfg.deepspeed` (stored as the DeepSpeed
            config file value), `cfg.bf16` and `cfg.fp16`.
        stage: optional ZeRO stage. When truthy it is exported as a string,
            and stage 3 additionally turns on the ZeRO-3 init flag.
    """
    env = os.environ
    env["ACCELERATE_USE_DEEPSPEED"] = "true"
    env["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed

    # bf16 takes precedence over fp16; neither set leaves the variable untouched.
    precision = "bf16" if cfg.bf16 else ("fp16" if cfg.fp16 else None)
    if precision is not None:
        env["ACCELERATE_MIXED_PRECISION"] = precision

    # Nothing stage-specific to export when no (or a falsy) stage was given.
    if not stage:
        return

    env["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
    if stage == 3:
        env["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
||||||
|
|
||||||
|
|
||||||
def setup_fsdp_envs(cfg):
|
def setup_fsdp_envs(cfg):
|
||||||
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||||
if cfg.fsdp_config.fsdp_activation_checkpointing:
|
if cfg.fsdp_config.fsdp_activation_checkpointing:
|
||||||
@@ -415,8 +429,14 @@ def prepare_optim_env(cfg):
|
|||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
stage = None
|
||||||
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
# check if the cfg.deepspeed is a file
|
||||||
|
if os.path.isfile(cfg.deepspeed):
|
||||||
|
# parse with json
|
||||||
|
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
|
||||||
|
deepspeed_config = json.load(fin)
|
||||||
|
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
|
||||||
|
setup_deepspeed_env(cfg, stage=stage)
|
||||||
|
|
||||||
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
|
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
|
||||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
||||||
|
|||||||
Reference in New Issue
Block a user