Merge pull request #26 from OpenAccess-AI-Collective/mpt-triton
Mpt triton
@@ -191,7 +191,9 @@ def train(
     if cfg.debug:
         logging.info("check_dataset_labels...")
         check_dataset_labels(
-            train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]),
+            train_dataset.select(
+                [random.randrange(0, len(train_dataset) - 1) for i in range(5)]
+            ),
             tokenizer,
         )
 
@@ -218,17 +220,20 @@ def train(
     logging.info("Starting trainer...")
     resume_from_checkpoint = cfg.resume_from_checkpoint
     if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")]
+        possible_checkpoints = [
+            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
+        ]
         if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(possible_checkpoints, key=lambda path: int(path.split('-')[-1]))
+            sorted_paths = sorted(
+                possible_checkpoints, key=lambda path: int(path.split("-")[-1])
+            )
             resume_from_checkpoint = sorted_paths[-1]
-            logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
+            logging.info(
+                f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
+            )
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
-    logging.info(
-        f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
-    )
+    logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     trainer.save_pretrained(cfg.output_dir)
 
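The auto-resume branch above sorts checkpoint directories by their numeric suffix rather than lexically, since `checkpoint-950` would otherwise sort after `checkpoint-1000`. A minimal standalone sketch of the same selection logic (the `out` directory below is only a placeholder, not part of this change):

```python
from pathlib import Path

# Hypothetical output directory; "checkpoint-*" folders are what the HF Trainer writes.
output_dir = Path("out")
possible_checkpoints = [str(cp) for cp in output_dir.glob("checkpoint-*")]

if possible_checkpoints:
    # Sort by the integer suffix so the newest checkpoint really is last.
    sorted_paths = sorted(
        possible_checkpoints, key=lambda path: int(path.split("-")[-1])
    )
    print(f"would resume from {sorted_paths[-1]}")
```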
setup.py
@@ -10,22 +10,22 @@ with open("./requirements.txt", "r") as requirements_file:
     install_requires.append(r)
 
 setup(
-    name='axolotl',
-    version='0.1',
+    name="axolotl",
+    version="0.1",
     description="You know you're going to axolotl questions",
-    package_dir={'': 'src'},
+    package_dir={"": "src"},
     packages=find_packages(),
     install_requires=install_requires,
     extras_require={
-        'int4': [
+        "int4": [
             "alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
         ],
-        'int4_triton': [
+        "int4_triton": [
             "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
         ],
-        'extras': [
-            'flash-attn',
-            'deepspeed',
-        ]
+        "extras": [
+            "flash-attn",
+            "deepspeed",
+        ],
     },
 )
@@ -31,6 +31,7 @@ class TokenizedPromptDataset(IterableDataset):
         except InvalidDataException:
             pass
 
 
+# TODO this isn't the best since it can't interleave datasets
 class ConstantLengthDataset(IterableDataset):
     """
@@ -40,6 +41,7 @@ class ConstantLengthDataset(IterableDataset):
             dataset (dataset.Dataset): Dataset with text files.
             seq_length (int): Length of token sequences to return.
     """
+
     def __init__(
         self,
         tokenizer,
@@ -93,14 +95,19 @@ class ConstantLengthDataset(IterableDataset):
                     : self.seq_length
                 ]
                 labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-                if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
+                if (
+                    labels.size() == input_ids.size()
+                    and attention_mask.size() == input_ids.size()
+                ):
                     yield {
                         "input_ids": input_ids,
                         "labels": labels,
                         "attention_mask": attention_mask,
                     }
                 else:
-                    logging.warning("dropping batch due to tensor size mismatch")
+                    logging.warning(
+                        "dropping batch due to tensor size mismatch"
+                    )
                 buffer = {"input_ids": [], "attention_mask": [], "labels": []}
                 buffer_len = 0
 
@@ -116,11 +123,15 @@ class ConstantLengthDataset(IterableDataset):
             attention_mask.append(1)
             labels.append(self.concat_token_id)
 
-            input_ids_with_concat = torch.tensor(input_ids, dtype=self.tokens_dtype)
+            input_ids_with_concat = torch.tensor(
+                input_ids, dtype=self.tokens_dtype
+            )
             attention_mask_with_concat = torch.tensor(
                 attention_mask, dtype=self.tokens_dtype
             )
-            labels_with_concat = torch.tensor(labels, dtype=self.tokens_dtype)
+            labels_with_concat = torch.tensor(
+                labels, dtype=self.tokens_dtype
+            )
 
             buffer["input_ids"].append(input_ids_with_concat)
             buffer["attention_mask"].append(attention_mask_with_concat)
@@ -126,10 +126,8 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
 
 
 class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str):
-        return (
-            prompt["text"]
-        )
+    def parse_instruction_fields(self, prompt) -> str:
+        return prompt["text"]
 
     def tokenize_prompt(self, prompt):
         instruction = self.parse_instruction_fields(prompt)
@@ -139,9 +137,7 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         return tokenized_full_prompt
 
     def _build_full_prompt(self, instruction):
-        return self.prompter.build_prompt(
-            instruction
-        )
+        return self.prompter.build_prompt(instruction)
 
 
 class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
@@ -149,8 +145,16 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         raise NotImplementedError
 
     def tokenize_prompt(self, prompt):
-        instruction, input, output, reflection, corrected = self.parse_instruction_fields(prompt)
-        full_prompt = self._build_full_prompt(instruction, input, output, reflection, corrected)
+        (
+            instruction,
+            input,
+            output,
+            reflection,
+            corrected,
+        ) = self.parse_instruction_fields(prompt)
+        full_prompt = self._build_full_prompt(
+            instruction, input, output, reflection, corrected
+        )
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
             user_prompt = self.prompter.build_prompt(
@@ -36,10 +36,7 @@ class JeopardyPrompter(AlpacaPrompter):
 
 
 class CompletionPrompter(AlpacaPrompter):
-    def build_prompt(
-        self,
-        instruction: str
-    ) -> str:
+    def build_prompt(self, instruction: str) -> str:
         return instruction
 
     def get_response(self, output: str) -> str:
@@ -75,7 +72,9 @@ class ReflectAlpacaPrompter:
         else:
             res = self.prompt_no_input.format(instruction=instruction)
         if output and reflection and corrected:
-            label = self.agent_label.format(output=output, reflection=reflection, corrected=corrected)
+            label = self.agent_label.format(
+                output=output, reflection=reflection, corrected=corrected
+            )
             res = f"{res}{label}"
         return res
 
@@ -200,9 +199,13 @@ class ShareGPTPrompter:
                 if len(parts) != 2:
                     break
                 parts[0] += sep
-                round_len = len(tokenizer(rou)["input_ids"]) - 1  # -1 ignores the bos_token generated for this
+                round_len = (
+                    len(tokenizer(rou)["input_ids"]) - 1
+                )  # -1 ignores the bos_token generated for this
                 # we have to strip the initial part, any dangling whitespace creates an additional ghost token
-                instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1  # -1 ignores the bos_token generated for this
+                instruction_len = (
+                    len(tokenizer(parts[0].strip())["input_ids"]) - 1
+                )  # -1 ignores the bos_token generated for this
                 target[cur_len : cur_len + instruction_len] = [
                     IGNORE_TOKEN_ID
                 ] * instruction_len
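The repeated `- 1` adjustments above compensate for the BOS token that a LLaMA-style tokenizer prepends to every encoded string. A rough sketch of that behaviour, assuming any tokenizer that adds a BOS token (the model id below is only a placeholder, not taken from this diff):

```python
from transformers import AutoTokenizer

# Placeholder tokenizer; any tokenizer that prepends a BOS token shows the same offset.
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

ids = tokenizer("USER: hello")["input_ids"]
# The first id is the BOS token, so the visible prompt contributes len(ids) - 1
# tokens, which is the quantity the masking arithmetic in the diff works with.
assert ids[0] == tokenizer.bos_token_id
print(len(ids), len(ids) - 1)
```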
@@ -212,7 +215,7 @@ class ShareGPTPrompter:
                 break
 
         # Fix: Truncate the target to have the same length as input_ids
-        target = target[:len(tokenized_result["input_ids"])]
+        target = target[: len(tokenized_result["input_ids"])]
         # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
 
         attention_mask = [
@@ -1,8 +1,15 @@
 import os
 
-from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+from transformers import (
+    Seq2SeqTrainer,
+    TrainerCallback,
+    TrainingArguments,
+    TrainerState,
+    TrainerControl,
+)
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
+
 class SavePeftModelCallback(TrainerCallback):
     def on_save(
         self,
@@ -11,7 +18,9 @@ class SavePeftModelCallback(TrainerCallback):
         control: TrainerControl,
         **kwargs,
     ):
-        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
+        checkpoint_folder = os.path.join(
+            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+        )
 
         peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
         kwargs["model"].save_pretrained(peft_model_path)
@@ -2,7 +2,13 @@ import logging
 from hashlib import md5
 from pathlib import Path
 
-from datasets import load_from_disk, load_dataset, IterableDataset, Dataset, concatenate_datasets
+from datasets import (
+    load_from_disk,
+    load_dataset,
+    IterableDataset,
+    Dataset,
+    concatenate_datasets,
+)
 from huggingface_hub import hf_hub_download
 
 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
@@ -75,7 +81,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
             else:
                 ds = load_dataset(d.path, streaming=True)
         else:
-            fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files)
+            fp = hf_hub_download(
+                repo_id=d.path, repo_type="dataset", filename=d.data_files
+            )
             ds = load_dataset("json", data_files=fp, streaming=True, split=None)
         if not ds:
             raise Exception("unhandled dataset load")
@@ -140,7 +148,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
             samples = samples + [i for i in d]
         dataset = Dataset.from_list(samples).shuffle(seed=42)
         if cfg.local_rank == 0:
-            logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
+            logging.info(
+                f"Saving merged prepared dataset to disk... {prepared_ds_path}"
+            )
             dataset.save_to_disk(prepared_ds_path)
 
     if cfg.max_packed_sequence_len is not None:
@@ -153,12 +163,14 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
         dataset = Dataset.from_list([_ for _ in constant_len_dataset])
 
     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
-        logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
-        dataset = dataset.shard(num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx)
+        logging.info(
+            f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
+        )
+        dataset = dataset.shard(
+            num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
+        )
 
-    dataset = dataset.train_test_split(
-        test_size=cfg.val_set_size, shuffle=False
-    )
+    dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"]
 
@@ -9,14 +9,18 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
+    AutoConfig,
 )
 
 try:
     from transformers import (
         LlamaForCausalLM,
         LlamaTokenizer,
     )
 except:
-    logging.warning("This version of transformers does not support Llama. Consider upgrading.")
+    logging.warning(
+        "This version of transformers does not support Llama. Consider upgrading."
+    )
 
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
@@ -40,7 +44,9 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     tokenizer = None
-    is_llama_derived_model = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower())
+    is_llama_derived_model = "llama" in base_model or (
+        cfg.model_type and "llama" in cfg.model_type.lower()
+    )
 
     if is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False:
@@ -49,11 +55,16 @@ def load_model(
             logging.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
     elif is_llama_derived_model and cfg.xformers_attention:
-        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
+        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
+            hijack_llama_attention,
+        )
+
         logging.info("patching with xformers attention")
         hijack_llama_attention()
 
-    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
+    torch_dtype = (
+        torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
+    )
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -74,8 +85,12 @@ def load_model(
         try:
             snapshot_download_kwargs = {}
             if cfg.base_model_ignore_patterns:
-                snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
-            cache_model_path = Path(snapshot_download(base_model, **snapshot_download_kwargs))
+                snapshot_download_kwargs[
+                    "ignore_patterns"
+                ] = cfg.base_model_ignore_patterns
+            cache_model_path = Path(
+                snapshot_download(base_model, **snapshot_download_kwargs)
+            )
             files = (
                 list(cache_model_path.glob("*.pt"))
                 + list(cache_model_path.glob("*.safetensors"))
@@ -116,8 +131,13 @@ def load_model(
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
             )
         else:
+            config = AutoConfig.from_pretrained(
+                base_model,
+                trust_remote_code=True if cfg.trust_remote_code is True else False,
+            )
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
+                config=config,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
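This is the part of the change that lets remote-code architectures such as MPT load through `AutoConfig`/`AutoModelForCausalLM`. A hedged, standalone sketch of the same pattern; the model id and the Triton attention toggle follow the mosaicml/mpt-7b model card and are assumptions, not something taken from this diff:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

base_model = "mosaicml/mpt-7b"  # placeholder remote-code model

config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)
# Per the MPT model card, the custom config exposes an attn_config dict whose
# attention implementation can be switched to Triton.
config.attn_config["attn_impl"] = "triton"

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    config=config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
```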
@@ -26,7 +26,10 @@ class InterpolatingLogScheduler(LRScheduler):
         if self.last_epoch <= 0:
             lrs = [self.min_lr for base_lr in self.base_lrs]
         elif self.last_epoch < self.num_steps:
-            lrs = [self.min_lr * (self.q ** (self.last_epoch - 1)) for base_lr in self.base_lrs]
+            lrs = [
+                self.min_lr * (self.q ** (self.last_epoch - 1))
+                for base_lr in self.base_lrs
+            ]
         else:
             lrs = [self.max_lr for base_lr in self.base_lrs]
 
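The wrapped comprehension steps the learning rate geometrically from `min_lr` toward `max_lr`, i.e. it interpolates in log space. A standalone sketch of that idea; the ratio `q` below is an assumption about how `self.q` would be computed and is not shown in this hunk:

```python
min_lr, max_lr, num_steps = 1e-6, 1e-4, 10

# Assumed ratio so that min_lr * q ** (num_steps - 1) == max_lr.
q = (max_lr / min_lr) ** (1 / (num_steps - 1))

for step in range(1, num_steps + 1):
    lr = min(min_lr * q ** (step - 1), max_lr)
    print(f"step {step:2d}: lr = {lr:.3e}")
```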
@@ -1,6 +1,7 @@
 from termcolor import colored
+import logging
 
 
 def check_dataset_labels(dataset, tokenizer):
     # the dataset is already shuffled, so let's just check the first 5 elements
     for idx in range(5):
@@ -11,7 +12,7 @@ def check_example_labels(example, tokenizer):
     # Get the input_ids, labels, and attention_mask from the dataset
     input_ids = example["input_ids"]
     labels = example["labels"]
-    attention_mask =example["attention_mask"]
+    attention_mask = example["attention_mask"]
 
     # You can compare the input_ids and labels element-wise
     # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
@@ -21,9 +22,7 @@ def check_example_labels(example, tokenizer):
     ):
         decoded_input_token = tokenizer.decode(input_id)
         # Choose the color based on whether the label has the ignore value or not
-        color = (
-            "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
-        )
+        color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
         colored_token = colored(decoded_input_token, color) + colored(
             f"({label_id}, {mask}, {input_id})", "white"
         )
@@ -30,16 +30,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         if cfg.logging_steps is not None
         else max(min(int(0.005 * total_num_steps), 10), 1)
     )
-    save_steps = (
-        cfg.save_steps
-        if cfg.save_steps is not None
-        else min(int(0.05 * total_num_steps), 200)
-    )
-    eval_steps = (
-        cfg.eval_steps
-        if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
-        else save_steps
-    )
+    save_steps = cfg.save_steps
+    eval_steps = cfg.eval_steps
 
     training_arguments_kwargs = {}
     if cfg.bf16 == "full":
@@ -86,26 +78,33 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
-        per_device_eval_batch_size=cfg.eval_batch_size if cfg.eval_batch_size is not None else cfg.micro_batch_size,
+        per_device_eval_batch_size=cfg.eval_batch_size
+        if cfg.eval_batch_size is not None
+        else cfg.micro_batch_size,
         gradient_accumulation_steps=cfg.gradient_accumulation_steps,
         eval_accumulation_steps=cfg.gradient_accumulation_steps,
         num_train_epochs=cfg.num_epochs,
         learning_rate=cfg.learning_rate,
         evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
-        save_strategy="steps",
+        save_strategy="steps" if save_steps else "epoch",
         eval_steps=eval_steps if cfg.val_set_size > 0 else None,
         save_steps=save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
         load_best_model_at_end=True
-        if cfg.val_set_size > 0 and save_steps % eval_steps == 0 and cfg.load_in_8bit is not True
+        if cfg.val_set_size > 0
+        and save_steps is not None
+        and save_steps % eval_steps == 0
+        and cfg.load_in_8bit is not True
         else False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
         optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
-        lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
+        lr_scheduler_type=cfg.lr_scheduler
+        if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
+        else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         **training_arguments_kwargs,
     )
@@ -158,6 +157,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             cfg.learning_rate,
             total_steps=total_num_steps,
             epochs=cfg.num_epochs,
+            div_factor=10,
             **lr_scheduler_kwargs,
         )
     elif cfg.lr_scheduler == "log_sweep":
@@ -191,7 +191,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         data_collator_kwargs["pad_to_multiple_of"] = 8
 
     callbacks = []
-    if cfg.adapter == 'lora':
+    if cfg.adapter == "lora":
         callbacks.append(SavePeftModelCallback)
 
     trainer = transformers.Trainer(
@@ -2,7 +2,9 @@ import os
 
 
 def setup_wandb_env_vars(cfg):
-    if cfg.wandb_project and len(cfg.wandb_project) > 0:
+    if cfg.wandb_mode and cfg.wandb_mode == "offline":
+        os.environ["WANDB_MODE"] = cfg.wandb_mode
+    elif cfg.wandb_project and len(cfg.wandb_project) > 0:
         os.environ["WANDB_PROJECT"] = cfg.wandb_project
         cfg.use_wandb = True
     if cfg.wandb_watch and len(cfg.wandb_watch) > 0: