diff --git a/configs/galactica_1_3B.yml b/configs/galactica_1_3B.yml
index ed722f34e..1682849cf 100644
--- a/configs/galactica_1_3B.yml
+++ b/configs/galactica_1_3B.yml
@@ -34,7 +34,7 @@ tf32: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
-special_tokens:
+tokens:
pad_token: "[PAD]"
bos_token: ""
eos_token: ""
diff --git a/configs/llama_7B_jeopardy.yml b/configs/llama_7B_jeopardy.yml
index 1f0fbf9cf..f73bec348 100644
--- a/configs/llama_7B_jeopardy.yml
+++ b/configs/llama_7B_jeopardy.yml
@@ -51,7 +51,7 @@ deepspeed:
weight_decay: 0.0001
fsdp:
fsdp_config:
-special_tokens:
+tokens:
pad_token: "[PAD]"
bos_token: ""
eos_token: ""
diff --git a/configs/stability_3b.yml b/configs/stability_3b.yml
index 080f4c753..c5f2198d8 100644
--- a/configs/stability_3b.yml
+++ b/configs/stability_3b.yml
@@ -49,7 +49,7 @@ deepspeed:
weight_decay: 0.01
fsdp:
fsdp_config:
-#special_tokens:
+#tokens:
# pad_token: "[PAD]"
# bos_token: ""
# eos_token: ""
diff --git a/examples/4bit-lora-7b/config.yml b/examples/4bit-lora-7b/config.yml
index 32cb7d680..345e0812e 100644
--- a/examples/4bit-lora-7b/config.yml
+++ b/examples/4bit-lora-7b/config.yml
@@ -55,7 +55,7 @@ deepspeed:
weight_decay: 0.0001
fsdp:
fsdp_config:
-special_tokens:
+tokens:
pad_token: "[PAD]"
bos_token: ""
eos_token: ""
diff --git a/examples/mpt-7b/config.yml b/examples/mpt-7b/config.yml
index 1323cc29b..f33452266 100644
--- a/examples/mpt-7b/config.yml
+++ b/examples/mpt-7b/config.yml
@@ -1,7 +1,6 @@
base_model: mosaicml/mpt-7b
base_model_config: mosaicml/mpt-7b
-model_type: AutoModelForCausalLM
-tokenizer_type: GPTNeoXTokenizer
+tokenizer_type: AutoTokenizer
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
load_in_8bit: false
datasets:
@@ -25,7 +24,7 @@ wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./mpt-alpaca-7b
-batch_size: 4
+batch_size: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
@@ -52,7 +51,7 @@ deepspeed:
weight_decay: 0.0001
fsdp:
fsdp_config:
-special_tokens:
+tokens:
pad_token: "<|padding|>"
bos_token: "<|endoftext|>"
eos_token: "<|endoftext|>"
diff --git a/examples/redpajama/config-3b.yml b/examples/redpajama/config-3b.yml
index 4268dd2cf..229d6615c 100644
--- a/examples/redpajama/config-3b.yml
+++ b/examples/redpajama/config-3b.yml
@@ -52,7 +52,7 @@ deepspeed:
weight_decay: 0.0001
fsdp:
fsdp_config:
-special_tokens:
+tokens:
pad_token: "<|padding|>"
bos_token: "<|endoftext|>"
eos_token: "<|endoftext|>"
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 915ba1de1..d6a920f5d 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -191,7 +191,9 @@ def train(
if cfg.debug:
logging.info("check_dataset_labels...")
check_dataset_labels(
- train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]),
+ train_dataset.select(
+ [random.randrange(len(train_dataset)) for _ in range(5)]  # stop is exclusive: len() covers the last row and works for a 1-row dataset
+ ),
tokenizer,
)
@@ -218,20 +220,23 @@ def train(
logging.info("Starting trainer...")
resume_from_checkpoint = cfg.resume_from_checkpoint
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
- possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")]
+ possible_checkpoints = [
+ str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
+ ]
if len(possible_checkpoints) > 0:
- sorted_paths = sorted(possible_checkpoints, key=lambda path: int(path.split('-')[-1]))
+ sorted_paths = sorted(
+ possible_checkpoints, key=lambda path: int(path.split("-")[-1])
+ )
resume_from_checkpoint = sorted_paths[-1]
- logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
+ logging.info(
+ f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
+ )
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
- logging.info(
- f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
- )
-
+ logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
- trainer.save_pretrained(cfg.output_dir)
+ model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
diff --git a/setup.py b/setup.py
index a183bcda1..134e4be66 100644
--- a/setup.py
+++ b/setup.py
@@ -10,22 +10,22 @@ with open("./requirements.txt", "r") as requirements_file:
install_requires.append(r)
setup(
- name='axolotl',
- version='0.1',
+ name="axolotl",
+ version="0.1",
description="You know you're going to axolotl questions",
- package_dir={'': 'src'},
+ package_dir={"": "src"},
packages=find_packages(),
install_requires=install_requires,
extras_require={
- 'int4': [
+ "int4": [
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
],
- 'int4_triton': [
+ "int4_triton": [
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
],
- 'extras': [
- 'flash-attn',
- 'deepspeed',
- ]
+ "extras": [
+ "flash-attn",
+ "deepspeed",
+ ],
},
)
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index deab5e438..d9acf5715 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -31,6 +31,7 @@ class TokenizedPromptDataset(IterableDataset):
except InvalidDataException:
pass
+
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""
@@ -40,6 +41,7 @@ class ConstantLengthDataset(IterableDataset):
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.
"""
+
def __init__(
self,
tokenizer,
@@ -93,14 +95,19 @@ class ConstantLengthDataset(IterableDataset):
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
- if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
+ if (
+ labels.size() == input_ids.size()
+ and attention_mask.size() == input_ids.size()
+ ):
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
}
else:
- logging.warning("dropping batch due to tensor size mismatch")
+ logging.warning(
+ "dropping batch due to tensor size mismatch"
+ )
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer_len = 0
@@ -116,11 +123,15 @@ class ConstantLengthDataset(IterableDataset):
attention_mask.append(1)
labels.append(self.concat_token_id)
- input_ids_with_concat = torch.tensor(input_ids, dtype=self.tokens_dtype)
+ input_ids_with_concat = torch.tensor(
+ input_ids, dtype=self.tokens_dtype
+ )
attention_mask_with_concat = torch.tensor(
attention_mask, dtype=self.tokens_dtype
)
- labels_with_concat = torch.tensor(labels, dtype=self.tokens_dtype)
+ labels_with_concat = torch.tensor(
+ labels, dtype=self.tokens_dtype
+ )
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 167648618..00d8ecbf9 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -126,10 +126,8 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str):
- return (
- prompt["text"]
- )
+ def parse_instruction_fields(self, prompt) -> str:
+ return prompt["text"]
def tokenize_prompt(self, prompt):
instruction = self.parse_instruction_fields(prompt)
@@ -139,9 +137,7 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
return tokenized_full_prompt
def _build_full_prompt(self, instruction):
- return self.prompter.build_prompt(
- instruction
- )
+ return self.prompter.build_prompt(instruction)
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
@@ -149,8 +145,16 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
raise NotImplementedError
def tokenize_prompt(self, prompt):
- instruction, input, output, reflection, corrected = self.parse_instruction_fields(prompt)
- full_prompt = self._build_full_prompt(instruction, input, output, reflection, corrected)
+ (
+ instruction,
+ input,
+ output,
+ reflection,
+ corrected,
+ ) = self.parse_instruction_fields(prompt)
+ full_prompt = self._build_full_prompt(
+ instruction, input, output, reflection, corrected
+ )
tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs:
user_prompt = self.prompter.build_prompt(
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 914cbd0de..3dc5d6433 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -36,10 +36,7 @@ class JeopardyPrompter(AlpacaPrompter):
class CompletionPrompter(AlpacaPrompter):
- def build_prompt(
- self,
- instruction: str
- ) -> str:
+ def build_prompt(self, instruction: str) -> str:
return instruction
def get_response(self, output: str) -> str:
@@ -75,7 +72,9 @@ class ReflectAlpacaPrompter:
else:
res = self.prompt_no_input.format(instruction=instruction)
if output and reflection and corrected:
- label = self.agent_label.format(output=output, reflection=reflection, corrected=corrected)
+ label = self.agent_label.format(
+ output=output, reflection=reflection, corrected=corrected
+ )
res = f"{res}{label}"
return res
@@ -200,9 +199,13 @@ class ShareGPTPrompter:
if len(parts) != 2:
break
parts[0] += sep
- round_len = len(tokenizer(rou)["input_ids"]) - 1 # -1 ignores the bos_token generated for this
+ round_len = (
+ len(tokenizer(rou)["input_ids"]) - 1
+ ) # -1 ignores the bos_token generated for this
# we have to strip the initial part, any dangling whitespace creates an additional ghost token
- instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1 # -1 ignores the bos_token generated for this
+ instruction_len = (
+ len(tokenizer(parts[0].strip())["input_ids"]) - 1
+ ) # -1 ignores the bos_token generated for this
target[cur_len : cur_len + instruction_len] = [
IGNORE_TOKEN_ID
] * instruction_len
@@ -212,7 +215,7 @@ class ShareGPTPrompter:
break
# Fix: Truncate the target to have the same length as input_ids
- target = target[:len(tokenized_result["input_ids"])]
+ target = target[: len(tokenized_result["input_ids"])]
# target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
attention_mask = [
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index aaf96bcb0..229cd9b98 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -1,8 +1,15 @@
import os
-from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+from transformers import (
+ Seq2SeqTrainer,
+ TrainerCallback,
+ TrainingArguments,
+ TrainerState,
+ TrainerControl,
+)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
class SavePeftModelCallback(TrainerCallback):
def on_save(
self,
@@ -11,7 +18,9 @@ class SavePeftModelCallback(TrainerCallback):
control: TrainerControl,
**kwargs,
):
- checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
+ checkpoint_folder = os.path.join(
+ args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+ )
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
kwargs["model"].save_pretrained(peft_model_path)
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index b217b50d7..581b48a88 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -2,7 +2,13 @@ import logging
from hashlib import md5
from pathlib import Path
-from datasets import load_from_disk, load_dataset, IterableDataset, Dataset, concatenate_datasets
+from datasets import (
+ load_from_disk,
+ load_dataset,
+ IterableDataset,
+ Dataset,
+ concatenate_datasets,
+)
from huggingface_hub import hf_hub_download
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
@@ -75,7 +81,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
else:
ds = load_dataset(d.path, streaming=True)
else:
- fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files)
+ fp = hf_hub_download(
+ repo_id=d.path, repo_type="dataset", filename=d.data_files
+ )
ds = load_dataset("json", data_files=fp, streaming=True, split=None)
if not ds:
raise Exception("unhandled dataset load")
@@ -140,7 +148,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
samples = samples + [i for i in d]
dataset = Dataset.from_list(samples).shuffle(seed=42)
if cfg.local_rank == 0:
- logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
+ logging.info(
+ f"Saving merged prepared dataset to disk... {prepared_ds_path}"
+ )
dataset.save_to_disk(prepared_ds_path)
if cfg.max_packed_sequence_len is not None:
@@ -153,12 +163,14 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
- logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
- dataset = dataset.shard(num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx)
+ logging.info(
+ f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
+ )
+ dataset = dataset.shard(
+ num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
+ )
- dataset = dataset.train_test_split(
- test_size=cfg.val_set_size, shuffle=False
- )
+ dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 8c80b2621..2ca84b795 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -9,14 +9,18 @@ from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
+ AutoConfig,
)
+
try:
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
)
except:
- logging.warning("This version of transformers does not support Llama. Consider upgrading.")
+ logging.warning(
+ "This version of transformers does not support Llama. Consider upgrading."
+ )
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
@@ -40,7 +44,9 @@ def load_model(
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
tokenizer = None
- is_llama_derived_model = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower())
+ is_llama_derived_model = "llama" in base_model or (
+ cfg.model_type and "llama" in cfg.model_type.lower()
+ )
if is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and inference is False:
@@ -49,11 +55,16 @@ def load_model(
logging.info("patching with flash attention")
replace_llama_attn_with_flash_attn()
elif is_llama_derived_model and cfg.xformers_attention:
- from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
+ from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
+ hijack_llama_attention,
+ )
+
logging.info("patching with xformers attention")
hijack_llama_attention()
- torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
+ torch_dtype = (
+ torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
+ )
try:
if cfg.load_4bit:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -74,8 +85,12 @@ def load_model(
try:
snapshot_download_kwargs = {}
if cfg.base_model_ignore_patterns:
- snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
- cache_model_path = Path(snapshot_download(base_model, **snapshot_download_kwargs))
+ snapshot_download_kwargs[
+ "ignore_patterns"
+ ] = cfg.base_model_ignore_patterns
+ cache_model_path = Path(
+ snapshot_download(base_model, **snapshot_download_kwargs)
+ )
files = (
list(cache_model_path.glob("*.pt"))
+ list(cache_model_path.glob("*.safetensors"))
@@ -116,8 +131,13 @@ def load_model(
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
else:
+ config = AutoConfig.from_pretrained(
+ base_model,
+ trust_remote_code=True if cfg.trust_remote_code is True else False,
+ )
model = AutoModelForCausalLM.from_pretrained(
base_model,
+ config=config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py
index 72916f037..b9b7e25be 100644
--- a/src/axolotl/utils/schedulers.py
+++ b/src/axolotl/utils/schedulers.py
@@ -26,7 +26,10 @@ class InterpolatingLogScheduler(LRScheduler):
if self.last_epoch <= 0:
lrs = [self.min_lr for base_lr in self.base_lrs]
elif self.last_epoch < self.num_steps:
- lrs = [self.min_lr * (self.q ** (self.last_epoch - 1)) for base_lr in self.base_lrs]
+ lrs = [
+ self.min_lr * (self.q ** (self.last_epoch - 1))
+ for base_lr in self.base_lrs
+ ]
else:
lrs = [self.max_lr for base_lr in self.base_lrs]
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index b9ffb1e1b..f23ca8a92 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -1,6 +1,7 @@
from termcolor import colored
import logging
+
def check_dataset_labels(dataset, tokenizer):
# the dataset is already shuffled, so let's just check the first 5 elements
for idx in range(5):
@@ -11,7 +12,7 @@ def check_example_labels(example, tokenizer):
# Get the input_ids, labels, and attention_mask from the dataset
input_ids = example["input_ids"]
labels = example["labels"]
- attention_mask =example["attention_mask"]
+ attention_mask = example["attention_mask"]
# You can compare the input_ids and labels element-wise
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
@@ -21,9 +22,7 @@ def check_example_labels(example, tokenizer):
):
decoded_input_token = tokenizer.decode(input_id)
# Choose the color based on whether the label has the ignore value or not
- color = (
- "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
- )
+ color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
colored_token = colored(decoded_input_token, color) + colored(
f"({label_id}, {mask}, {input_id})", "white"
)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index aa8c72a3c..cd9f94229 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -30,16 +30,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
- save_steps = (
- cfg.save_steps
- if cfg.save_steps is not None
- else min(int(0.05 * total_num_steps), 200)
- )
- eval_steps = (
- cfg.eval_steps
- if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
- else save_steps
- )
+ save_steps = cfg.save_steps
+ eval_steps = cfg.eval_steps
training_arguments_kwargs = {}
if cfg.bf16 == "full":
@@ -86,26 +78,33 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_args = transformers.TrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size,
- per_device_eval_batch_size=cfg.eval_batch_size if cfg.eval_batch_size is not None else cfg.micro_batch_size,
+ per_device_eval_batch_size=cfg.eval_batch_size
+ if cfg.eval_batch_size is not None
+ else cfg.micro_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
- save_strategy="steps",
+ save_strategy="steps" if save_steps else "epoch",
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
save_steps=save_steps,
output_dir=cfg.output_dir,
save_total_limit=3,
load_best_model_at_end=True
- if cfg.val_set_size > 0 and save_steps % eval_steps == 0 and cfg.load_in_8bit is not True
+ if cfg.val_set_size > 0
+ and save_steps  # None/0 selects epoch-based saving via save_strategy, so step alignment cannot hold
+ and eval_steps and save_steps % eval_steps == 0  # eval_steps no longer has a default; guard None/0 before the modulo
+ and cfg.load_in_8bit is not True
else False,
ddp_find_unused_parameters=False if cfg.ddp else None,
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
- lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
+ lr_scheduler_type=cfg.lr_scheduler
+ if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
+ else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
**training_arguments_kwargs,
)
@@ -158,6 +157,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
cfg.learning_rate,
total_steps=total_num_steps,
epochs=cfg.num_epochs,
+ div_factor=10,
**lr_scheduler_kwargs,
)
elif cfg.lr_scheduler == "log_sweep":
diff --git a/src/axolotl/utils/wandb.py b/src/axolotl/utils/wandb.py
index 1e805c6c6..992bb1a5f 100644
--- a/src/axolotl/utils/wandb.py
+++ b/src/axolotl/utils/wandb.py
@@ -2,7 +2,9 @@ import os
def setup_wandb_env_vars(cfg):
- if cfg.wandb_project and len(cfg.wandb_project) > 0:
+ if cfg.wandb_mode and cfg.wandb_mode == "offline":
+ os.environ["WANDB_MODE"] = cfg.wandb_mode
+ elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True
if cfg.wandb_watch and len(cfg.wandb_watch) > 0: