diff --git a/TODO.md b/TODO.md new file mode 100644 index 000000000..2002bbbaf --- /dev/null +++ b/TODO.md @@ -0,0 +1,10 @@ +# todo list + +- [] Validation of parameters for combinations that won't work + + + +## things that are known not to work + +- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203 +- adamw_bnb_8bit doesn't play well with FSDP offload diff --git a/ds_config.json b/ds_config.json index ffd6f2075..65955377c 100644 --- a/ds_config.json +++ b/ds_config.json @@ -10,21 +10,42 @@ "hysteresis": 2, "min_loss_scale": 1 }, - "scheduler": { - "type": "OneCycle", + "optimizer": { + "type": "Adam", "params": { - "cycle_min_lr": 1e-7, - "cycle_max_lr": 1e-4 + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto", + "total_num_steps": "auto" } }, "zero_optimization": { "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, "overlap_comm": true, "allgather_partitions": true, "allgather_bucket_size": 5e8, "contiguous_gradients": true, "reduce_bucket_size": "auto", "reduce_scatter": true, + "stage3_max_live_parameters": 0, + "stage3_max_reuse_distance": 0, "stage3_gather_16bit_weights_on_model_save": true }, "gradient_accumulation_steps": "auto", diff --git a/scripts/finetune.py b/scripts/finetune.py index 858f33f9a..a8cfe2a03 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -1,5 +1,7 @@ +import importlib import logging import os +import pathlib import random import signal import sys @@ -11,6 +13,8 @@ import yaml from attrdict import AttrDefault # add src to the pythonpath so we don't need to pip install this +from axolotl.utils.tokenization import check_dataset_labels + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) @@ -42,48 +46,20 @@ def choose_device(cfg): cfg.device_map = {"": cfg.device} -def check_dataset_labels(dataset, tokenizer): - from termcolor import colored - - # the dataset is already shuffled, so let's just check the first 5 elements - for idx in range(5): - # Get the input_ids, labels, and attention_mask from the dataset - input_ids = dataset[idx]["input_ids"] - labels = dataset[idx]["labels"] - attention_mask = dataset[idx]["attention_mask"] - - # You can compare the input_ids and labels element-wise - # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0 - colored_tokens = [] - for i, (input_id, label_id, mask) in enumerate( - zip(input_ids, labels, attention_mask) - ): - decoded_input_token = tokenizer.decode(input_id) - # Choose the color based on whether the label has the ignore value or not - color = ( - "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") - ) - colored_token = colored(decoded_input_token, color) + colored( - f"({label_id}, {mask})", "white" - ) - colored_tokens.append(colored_token) - - logging.info(" ".join(colored_tokens)) - logging.info("\n\n\n") - - -def do_inference(cfg, model, tokenizer): +def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): tokenizer.add_special_tokens({"unk_token": ""}) tokenizer.add_special_tokens({"bos_token": ""}) tokenizer.add_special_tokens({"eos_token": ""}) - from axolotl.prompters import ReflectAlpacaPrompter + prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter) while True: - instruction = str(input("Give me an instruction: ")) + # support for multiline inputs + print("Give me an instruction (Ctrl + D to finish): ") + instruction = pathlib.Path("/proc/self/fd/0").read_text() if not instruction: return - prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction) + prompt = prompter_module().build_prompt(instruction=instruction) batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) model.eval() @@ -174,8 +150,8 @@ def train( cfg.bf16 = False # Load the model and tokenizer - logging.info("loading model, tokenizer, and lora_config...") - model, tokenizer, lora_config = load_model( + logging.info("loading model, tokenizer, and peft_config...") + model, tokenizer, peft_config = load_model( cfg.base_model, cfg.base_model_config, cfg.model_type, @@ -190,6 +166,10 @@ def train( do_inference(cfg, model, tokenizer) return + if "shard" in kwargs: + model.save_pretrained(cfg.output_dir) + return + train_dataset, eval_dataset = load_prepare_datasets( tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH ) @@ -199,8 +179,9 @@ def train( return if cfg.debug: + logging.info("check_dataset_labels...") check_dataset_labels( - train_dataset.select([random.randrange(0, len(train_dataset) - 1)]), + train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]), tokenizer, ) @@ -213,9 +194,9 @@ def train( model = torch.compile(model) # go ahead and presave, so we have the adapter config available to inspect - if lora_config: + if peft_config: logging.info(f"Pre-saving adapter config to {cfg.output_dir}") - lora_config.save_pretrained(cfg.output_dir) + peft_config.save_pretrained(cfg.output_dir) # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model if cfg.local_rank == 0: @@ -234,12 +215,11 @@ def train( logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}") trainer.train(resume_from_checkpoint=resume_from_checkpoint) - if cfg.local_rank == 0: - # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading - logging.info( - f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}" - ) - model.save_pretrained(cfg.output_dir) + logging.info( + f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}" + ) + # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading + trainer.save_model(cfg.output_dir) if __name__ == "__main__": diff --git a/scripts/setup-runpod.sh b/scripts/setup-runpod.sh index cc1212dc3..660df086f 100644 --- a/scripts/setup-runpod.sh +++ b/scripts/setup-runpod.sh @@ -26,6 +26,15 @@ if [ -z "${TORCH_CUDA_ARCH_LIST}" ]; then # only set this if not set yet export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" fi +# install flash-attn and deepspeed from pre-built wheels for this specific container b/c these take forever to install +mkdir -p /workspace/wheels +cd /workspace/wheels +curl -L -O https://github.com/winglian/axolotl/raw/wheels/wheels/deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl +curl -L -O https://github.com/winglian/axolotl/raw/wheels/wheels/flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl +pip install deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl +pip install flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl +pip install "peft @ git+https://github.com/huggingface/peft.git@main" --force-reinstall --no-dependencies + cd /workspace/ git clone https://github.com/winglian/axolotl.git cd axolotl diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 903ee4385..c2acf60c3 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -127,7 +127,7 @@ conv_vicuna_v1_1 = Conversation( class ShareGPTPrompter: - def build_prompt(self, source, tokenizer): + def build_prompt(self, source, tokenizer, sequence_len=2048): # ignore the system prompt if provided if source[0]["from"] == "system": source.pop(0) @@ -157,13 +157,14 @@ class ShareGPTPrompter: role = roles[sentence["from"]] assert role == conv.roles[j % 2] conv.append_message(role, sentence["value"]) + # TODO, this concatenates everything, but doesn't seem to properly add the eos_token_id, as the eos_token gets split up conversation = conv.get_prompt() # Tokenize conversations tokenized_result = tokenizer( conversation, truncation=True, - max_length=2048, # FIXME + max_length=sequence_len, # FIXME padding=False, return_tensors=None, ) @@ -173,7 +174,9 @@ class ShareGPTPrompter: sep = conv.sep + conv.roles[1] + ": " rounds = conversation.split(conv.sep2) + rounds = [r + conv.sep2 for r in rounds] cur_len = 1 + target[0] = IGNORE_TOKEN_ID # mask out the bos for i, rou in enumerate(rounds): if rou == "": break @@ -182,19 +185,27 @@ class ShareGPTPrompter: if len(parts) != 2: break parts[0] += sep - round_len = len(tokenizer(rou)["input_ids"]) - instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2 + round_len = len(tokenizer(rou)["input_ids"]) - 1 # -1 ignores the bos_token generated for this + # we have to strip the initial part, any dangling whitespace creates an additional ghost token + instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1 # -1 ignores the bos_token generated for this target[cur_len : cur_len + instruction_len] = [ IGNORE_TOKEN_ID ] * instruction_len cur_len += round_len - target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len) + if cur_len >= sequence_len: + break + + # Fix: Truncate the target to have the same length as input_ids + target = target[:len(tokenized_result["input_ids"])] + # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len) + attention_mask = [ 1 if x != tokenizer.pad_token_id else 0 for x in tokenized_result["input_ids"] ] + # TODO truncate len to sequence_len return dict( input_ids=tokenized_result["input_ids"], labels=target, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index a14f89b19..1fc47a87f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -53,7 +53,7 @@ def load_model( logging.info("patching with xformers attention") hijack_llama_attention() - torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,) + torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32 try: if cfg.load_4bit: from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import ( @@ -101,30 +101,23 @@ def load_model( ) load_in_8bit = False elif is_llama_derived_model and "LlamaForCausalLM" in globals(): - if not cfg.load_in_8bit: - model = LlamaForCausalLM.from_pretrained( - base_model, - device_map=cfg.device_map, - ) - else: - model = LlamaForCausalLM.from_pretrained( - base_model, - load_in_8bit=cfg.load_in_8bit, - torch_dtype=torch_dtype, - device_map=cfg.device_map, - ) - + model = LlamaForCausalLM.from_pretrained( + base_model, + load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, + torch_dtype=torch_dtype, + device_map=cfg.device_map, + ) elif model_type: model = getattr(transformers, model_type).from_pretrained( base_model, - load_in_8bit=cfg.load_in_8bit, + load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, ) else: model = AutoModelForCausalLM.from_pretrained( base_model, - load_in_8bit=cfg.load_in_8bit, + load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, ) @@ -135,7 +128,7 @@ def load_model( logging.exception(e) model = AutoModelForCausalLM.from_pretrained( base_model, - load_in_8bit=cfg.load_in_8bit, + load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, ) @@ -147,7 +140,7 @@ def load_model( else: tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model) except: - tokenizer = AutoTokenizer.from_pretrained(base_model) + tokenizer = AutoTokenizer.from_pretrained(base_model_config) logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") @@ -161,12 +154,12 @@ def load_model( tokenizer.add_special_tokens({"pad_token": "[PAD]"}) os.environ["TOKENIZERS_PARALLELISM"] = "false" - if cfg.special_tokens: - for k, v in cfg.special_tokens.items(): - setattr(tokenizer, k, v) + if cfg.tokens: + for k, v in cfg.tokens.items(): + tokenizer.add_special_tokens({k: v}) - if load_in_8bit and not cfg.load_4bit: - logging.info("converting model w/ prepare_model_for_int8_training") + if cfg.adapter and load_in_8bit and not cfg.load_4bit: + logging.info("converting PEFT model w/ prepare_model_for_int8_training") model = prepare_model_for_int8_training(model) model, lora_config = load_adapter(model, cfg, adapter) @@ -186,6 +179,11 @@ def load_model( m.scales = m.scales.half() m.bias = m.bias.half() + if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1: + model.is_parallelizable = True + model.model_parallel = True + + # TODO resume_from_checkpoint handling return model, tokenizer, lora_config @@ -197,11 +195,41 @@ def load_adapter(model, cfg, adapter): return model, None if adapter == "lora": return load_lora(model, cfg) - # TODO support Llama-Adapter once merged into peft https://github.com/huggingface/peft/pulls + if adapter == "llama-adapter": + return load_llama_adapter(model, cfg) raise NotImplementedError(f"{adapter} peft adapter not available") +def load_llama_adapter(model, cfg): + # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] + from peft import ( + AdaptionPromptConfig, + get_peft_model, + PeftModel, + ) + + peft_config = AdaptionPromptConfig( + adapter_layers=cfg.peft_adapter.layers, # layers (L) + adapter_len=cfg.peft_adapter.len, # prompt length (K) + task_type="CAUSAL_LM", + ) + + if cfg.peft_model_dir: + model = PeftModel.from_pretrained( + model, + cfg.lora_model_dir, + device_map=cfg.device_map, + torch_dtype=torch.float16, + ) + else: + model = get_peft_model(model, peft_config) + + model.print_trainable_parameters() + + return model, peft_config + + def load_lora(model, cfg): # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] @@ -213,27 +241,26 @@ def load_lora(model, cfg): lora_config = None - if cfg.adapter == "lora": - lora_config = LoraConfig( - r=cfg.lora_r, - lora_alpha=cfg.lora_alpha, - target_modules=cfg.lora_target_modules, - lora_dropout=cfg.lora_dropout, - fan_in_fan_out=cfg.lora_fan_in_fan_out, - bias="none", - task_type="CAUSAL_LM", + lora_config = LoraConfig( + r=cfg.lora_r, + lora_alpha=cfg.lora_alpha, + target_modules=cfg.lora_target_modules, + lora_dropout=cfg.lora_dropout, + fan_in_fan_out=cfg.lora_fan_in_fan_out, + bias="none", + task_type="CAUSAL_LM", + ) + + if cfg.lora_model_dir: + model = PeftModel.from_pretrained( + model, + cfg.lora_model_dir, + device_map=cfg.device_map, + torch_dtype=torch.float16, ) + else: + model = get_peft_model(model, lora_config) - if cfg.lora_model_dir: - model = PeftModel.from_pretrained( - model, - cfg.lora_model_dir, - device_map=cfg.device_map, - torch_dtype=torch.float16, - ) - else: - model = get_peft_model(model, lora_config) - - model.print_trainable_parameters() + model.print_trainable_parameters() return model, lora_config diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py new file mode 100644 index 000000000..72916f037 --- /dev/null +++ b/src/axolotl/utils/schedulers.py @@ -0,0 +1,33 @@ +from torch.optim.lr_scheduler import LRScheduler + + +class InterpolatingLogScheduler(LRScheduler): + def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1): + """A scheduler that interpolates learning rates in a logarithmic fashion + + Args: + - optimizer: pytorch optimizer + - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr + - min_lr: float, the minimum learning rate + - max_lr: float, the maximum learning rate + + Usage: + fc = nn.Linear(1,1) + optimizer = optim.Adam(fc.parameters()) + lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4) + """ + self.num_steps = num_steps + self.min_lr = min_lr + self.max_lr = max_lr + self.q = (max_lr / min_lr) ** (1 / (num_steps - 1)) + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch <= 0: + lrs = [self.min_lr for base_lr in self.base_lrs] + elif self.last_epoch < self.num_steps: + lrs = [self.min_lr * (self.q ** (self.last_epoch - 1)) for base_lr in self.base_lrs] + else: + lrs = [self.max_lr for base_lr in self.base_lrs] + + return lrs diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py new file mode 100644 index 000000000..b9ffb1e1b --- /dev/null +++ b/src/axolotl/utils/tokenization.py @@ -0,0 +1,33 @@ +from termcolor import colored +import logging + +def check_dataset_labels(dataset, tokenizer): + # the dataset is already shuffled, so let's just check the first 5 elements + for idx in range(5): + check_example_labels(dataset[idx], tokenizer) + + +def check_example_labels(example, tokenizer): + # Get the input_ids, labels, and attention_mask from the dataset + input_ids = example["input_ids"] + labels = example["labels"] + attention_mask =example["attention_mask"] + + # You can compare the input_ids and labels element-wise + # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0 + colored_tokens = [] + for i, (input_id, label_id, mask) in enumerate( + zip(input_ids, labels, attention_mask) + ): + decoded_input_token = tokenizer.decode(input_id) + # Choose the color based on whether the label has the ignore value or not + color = ( + "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") + ) + colored_token = colored(decoded_input_token, color) + colored( + f"({label_id}, {mask}, {input_id})", "white" + ) + colored_tokens.append(colored_token) + + logging.info(" ".join(colored_tokens)) + logging.info("\n\n\n") diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index e94ea48d6..3c6aca179 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -1,5 +1,7 @@ +import importlib import math import os +import sys from pathlib import Path import bitsandbytes as bnb @@ -10,14 +12,33 @@ from torch.optim.lr_scheduler import OneCycleLR from transformers import EarlyStoppingCallback from transformers.trainer_pt_utils import get_parameter_names +from axolotl.utils.schedulers import InterpolatingLogScheduler + def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) - warmup_steps = cfg.warmup_steps if cfg.warmup_steps else min(int(0.03 * total_num_steps), 100) - logging_steps = cfg.logging_steps if cfg.logging_steps else max(min(int(0.005 * total_num_steps), 10), 1) - save_steps = eval_steps = cfg.save_steps if cfg.save_steps else min(int(0.05 * total_num_steps), 200) + warmup_steps = ( + cfg.warmup_steps + if cfg.warmup_steps is not None + else min(int(0.03 * total_num_steps), 100) + ) + logging_steps = ( + cfg.logging_steps + if cfg.logging_steps is not None + else max(min(int(0.005 * total_num_steps), 10), 1) + ) + save_steps = ( + cfg.save_steps + if cfg.save_steps is not None + else min(int(0.05 * total_num_steps), 200) + ) + eval_steps = ( + cfg.eval_steps + if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0 + else save_steps + ) training_arguments_kwargs = {} if cfg.bf16 == "full": @@ -29,15 +50,32 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): training_arguments_kwargs["logging_steps"] = logging_steps if cfg.gradient_checkpointing is not None: if cfg.load_4bit: - from alpaca_lora_4bit.gradient_checkpointing import apply_gradient_checkpointing - gradient_checkpointing_ratio = cfg.gradient_checkpointing_ratio if cfg.gradient_checkpointing_ratio else 1.0 - apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio) - else: - training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing + from alpaca_lora_4bit.gradient_checkpointing import ( + apply_gradient_checkpointing, + ) + gradient_checkpointing_ratio = ( + cfg.gradient_checkpointing_ratio + if cfg.gradient_checkpointing_ratio + else 1.0 + ) + apply_gradient_checkpointing( + model, checkpoint_ratio=gradient_checkpointing_ratio + ) + else: + training_arguments_kwargs[ + "gradient_checkpointing" + ] = cfg.gradient_checkpointing + if cfg.fsdp: + training_arguments_kwargs["fsdp"] = cfg.fsdp + if cfg.fsdp_config: + training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config) # deepspeed - if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1: + if ( + os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" + and torch.cuda.device_count() > 1 + ): if cfg.deepspeed: training_arguments_kwargs["deepspeed"] = cfg.deepspeed else: @@ -49,6 +87,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): per_device_train_batch_size=cfg.micro_batch_size, per_device_eval_batch_size=cfg.eval_batch_size, gradient_accumulation_steps=cfg.gradient_accumulation_steps, + eval_accumulation_steps=cfg.gradient_accumulation_steps, num_train_epochs=cfg.num_epochs, learning_rate=cfg.learning_rate, evaluation_strategy="steps" if cfg.val_set_size > 0 else "no", @@ -57,31 +96,51 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): save_steps=save_steps, output_dir=cfg.output_dir, save_total_limit=3, - load_best_model_at_end=True if cfg.val_set_size > 0 and save_steps % eval_steps == 0 else False, + load_best_model_at_end=True + if cfg.val_set_size > 0 and save_steps % eval_steps == 0 + else False, ddp_find_unused_parameters=False if cfg.ddp else None, group_by_length=cfg.group_by_length, report_to="wandb" if cfg.use_wandb else None, run_name=cfg.wandb_run_id if cfg.use_wandb else None, + optim=cfg.optimizer if cfg.optimizer else None, + lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine", + weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0, **training_arguments_kwargs, ) trainer_kwargs = {} - if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs: + if cfg.optimizer == "adamw_anyprecision": + if Path(cfg.torchdistx_path).exists(): + sys.path.append(cfg.torchdistx_path) + importlib.import_module("torchdistx") + if ( + cfg.optimizer == "adamw_bnb_8bit" + and not cfg.load_4bit + and not "deepspeed" in training_arguments_kwargs + ): decay_parameters = get_parameter_names(model, [nn.LayerNorm]) decay_parameters = [name for name in decay_parameters if "bias" not in name] optimizer_grouped_parameters = [ { - "params": [p for n, p in model.named_parameters() if n in decay_parameters], + "params": [ + p + for n, p in model.named_parameters() + if (n in decay_parameters and p.requires_grad) + ], "weight_decay": training_args.weight_decay, }, { "params": [ - p for n, p in model.named_parameters() if n not in decay_parameters + p + for n, p in model.named_parameters() + if (n not in decay_parameters and p.requires_grad) ], "weight_decay": 0.0, }, ] + optimizer = bnb.optim.Adam8bit( optimizer_grouped_parameters, betas=(training_args.adam_beta1, training_args.adam_beta2), @@ -97,8 +156,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): optimizer, cfg.learning_rate, total_steps=total_num_steps, + epochs=cfg.num_epochs, **lr_scheduler_kwargs, ) + elif cfg.lr_scheduler == "log_sweep": + lr_scheduler = InterpolatingLogScheduler( + optimizer, + cfg.warmup_steps, + cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10, + cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10, + ) else: lr_scheduler = transformers.get_cosine_schedule_with_warmup( optimizer,