black formatting
This commit is contained in:
@@ -6,12 +6,13 @@ import fire
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
src_dir = os.path.join(project_root, 'src')
|
src_dir = os.path.join(project_root, "src")
|
||||||
sys.path.insert(0, src_dir)
|
sys.path.insert(0, src_dir)
|
||||||
|
|
||||||
from axolotl.convert import *
|
from axolotl.convert import *
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
input: Path,
|
input: Path,
|
||||||
output: Optional[Path] = None,
|
output: Optional[Path] = None,
|
||||||
@@ -25,9 +26,7 @@ def main(
|
|||||||
json_parser = JsonParser()
|
json_parser = JsonParser()
|
||||||
jsonl_serializer = JsonlSerializer()
|
jsonl_serializer = JsonlSerializer()
|
||||||
|
|
||||||
converter = JsonToJsonlConverter(
|
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
|
||||||
file_reader, writer, json_parser, jsonl_serializer
|
|
||||||
)
|
|
||||||
|
|
||||||
converter.convert(input, output)
|
converter.convert(input, output)
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,8 @@ from datasets import load_dataset, IterableDataset, Dataset
|
|||||||
from peft import (
|
from peft import (
|
||||||
LoraConfig,
|
LoraConfig,
|
||||||
get_peft_model,
|
get_peft_model,
|
||||||
prepare_model_for_int8_training, get_peft_model_state_dict,
|
prepare_model_for_int8_training,
|
||||||
|
get_peft_model_state_dict,
|
||||||
)
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
@@ -22,15 +23,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
src_dir = os.path.join(project_root, 'src')
|
src_dir = os.path.join(project_root, "src")
|
||||||
sys.path.insert(0, src_dir)
|
sys.path.insert(0, src_dir)
|
||||||
|
|
||||||
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
|
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
|
||||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
|
from axolotl.prompt_tokenizers import (
|
||||||
LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
|
AlpacaPromptTokenizingStrategy,
|
||||||
|
ShareGPTPromptTokenizingStrategy,
|
||||||
|
LLAMA_DEFAULT_PAD_TOKEN,
|
||||||
|
GPTeacherPromptTokenizingStrategy,
|
||||||
|
)
|
||||||
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
|
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
|
||||||
|
|
||||||
|
|
||||||
def setup_wandb_env_vars(cfg):
|
def setup_wandb_env_vars(cfg):
|
||||||
if len(cfg.wandb_project) > 0:
|
if len(cfg.wandb_project) > 0:
|
||||||
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
||||||
@@ -68,7 +74,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
|||||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||||
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
if cfg.load_in_8bit:
|
if cfg.load_in_8bit:
|
||||||
@@ -94,11 +100,11 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
|||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
config: Path = Path('configs/pythia_1_2B_alpaca.yml'),
|
config: Path = Path("configs/pythia_1_2B_alpaca.yml"),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# load the config from the yaml file
|
# load the config from the yaml file
|
||||||
with open(config, 'r') as f:
|
with open(config, "r") as f:
|
||||||
cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader))
|
cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader))
|
||||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||||
# then overwrite the value
|
# then overwrite the value
|
||||||
@@ -114,36 +120,52 @@ def train(
|
|||||||
cfg.ddp = cfg.world_size != 1
|
cfg.ddp = cfg.world_size != 1
|
||||||
if cfg.ddp:
|
if cfg.ddp:
|
||||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps // cfg.world_size
|
cfg.gradient_accumulation_steps = (
|
||||||
|
cfg.gradient_accumulation_steps // cfg.world_size
|
||||||
|
)
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
model, tokenizer, lora_config = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
|
model, tokenizer, lora_config = load_model(
|
||||||
|
cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter
|
||||||
|
)
|
||||||
datasets = []
|
datasets = []
|
||||||
for d in cfg.datasets:
|
for d in cfg.datasets:
|
||||||
ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, split=None)
|
ds: IterableDataset = load_dataset(
|
||||||
|
"json", data_files=d.path, streaming=True, split=None
|
||||||
|
)
|
||||||
if d.type == "alpaca":
|
if d.type == "alpaca":
|
||||||
ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
|
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||||
|
AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||||
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d.type == "gpteacher":
|
elif d.type == "gpteacher":
|
||||||
ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
|
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
||||||
|
GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||||
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d.type == "sharegpt":
|
elif d.type == "sharegpt":
|
||||||
ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
|
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||||
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
constant_len_dataset = ConstantLengthDataset(tokenizer, datasets, seq_length=cfg.sequence_len)
|
constant_len_dataset = ConstantLengthDataset(
|
||||||
constant_len_dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
|
tokenizer, datasets, seq_length=cfg.sequence_len
|
||||||
test_size=cfg.val_set_size, shuffle=True, seed=42
|
|
||||||
)
|
)
|
||||||
|
constant_len_dataset = Dataset.from_list(
|
||||||
|
[_ for _ in constant_len_dataset]
|
||||||
|
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
|
||||||
|
|
||||||
print(constant_len_dataset)
|
print(constant_len_dataset)
|
||||||
train_dataset = constant_len_dataset["train"]
|
train_dataset = constant_len_dataset["train"]
|
||||||
eval_dataset = constant_len_dataset["test"]
|
eval_dataset = constant_len_dataset["test"]
|
||||||
|
|
||||||
total_num_steps = int(math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size))
|
total_num_steps = int(
|
||||||
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
|
)
|
||||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||||
logging_steps = min(int(0.005 * total_num_steps), 10)
|
logging_steps = min(int(0.005 * total_num_steps), 10)
|
||||||
save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
|
save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
|
||||||
@@ -178,7 +200,9 @@ def train(
|
|||||||
"weight_decay": training_args.weight_decay,
|
"weight_decay": training_args.weight_decay,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"params": [p for n, p in model.named_parameters() if n not in decay_parameters],
|
"params": [
|
||||||
|
p for n, p in model.named_parameters() if n not in decay_parameters
|
||||||
|
],
|
||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
@@ -210,18 +234,16 @@ def train(
|
|||||||
|
|
||||||
old_state_dict = model.state_dict
|
old_state_dict = model.state_dict
|
||||||
model.state_dict = (
|
model.state_dict = (
|
||||||
lambda self, *_, **__: get_peft_model_state_dict(
|
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
||||||
self, old_state_dict()
|
|
||||||
)
|
|
||||||
).__get__(model, type(model))
|
).__get__(model, type(model))
|
||||||
|
|
||||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, lambda signal, frame: (
|
signal.signal(
|
||||||
model.save_pretrained(cfg.output_dir),
|
signal.SIGINT,
|
||||||
exit(0)
|
lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
|
||||||
))
|
)
|
||||||
|
|
||||||
# go ahead and presave the adapter config
|
# go ahead and presave the adapter config
|
||||||
lora_config.save_pretrained(cfg.output_dir)
|
lora_config.save_pretrained(cfg.output_dir)
|
||||||
@@ -229,5 +251,6 @@ def train(
|
|||||||
|
|
||||||
model.save_pretrained(cfg.output_dir)
|
model.save_pretrained(cfg.output_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
fire.Fire(train)
|
fire.Fire(train)
|
||||||
|
|||||||
@@ -47,5 +47,3 @@ class JsonToJsonlConverter:
|
|||||||
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
|
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
|
||||||
jsonl_content = self.jsonl_serializer.serialize(data)
|
jsonl_content = self.jsonl_serializer.serialize(data)
|
||||||
self.file_writer.write(jsonl_content)
|
self.file_writer.write(jsonl_content)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -71,10 +71,18 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
else:
|
else:
|
||||||
example_len = 0
|
example_len = 0
|
||||||
|
|
||||||
if not example_len or buffer_len + int(add_concat_token) + example_len > self.seq_length:
|
if (
|
||||||
|
not example_len
|
||||||
|
or buffer_len + int(add_concat_token) + example_len
|
||||||
|
> self.seq_length
|
||||||
|
):
|
||||||
if buffer["input_ids"]:
|
if buffer["input_ids"]:
|
||||||
input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
|
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
|
||||||
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length]
|
: self.seq_length
|
||||||
|
]
|
||||||
|
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
|
||||||
|
: self.seq_length
|
||||||
|
]
|
||||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||||
yield {
|
yield {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
@@ -95,7 +103,9 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
labels.append(self.concat_token_id)
|
labels.append(self.concat_token_id)
|
||||||
|
|
||||||
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
|
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
|
||||||
attention_mask_with_concat = torch.tensor(attention_mask, dtype=torch.long)
|
attention_mask_with_concat = torch.tensor(
|
||||||
|
attention_mask, dtype=torch.long
|
||||||
|
)
|
||||||
labels_with_concat = torch.tensor(labels, dtype=torch.long)
|
labels_with_concat = torch.tensor(labels, dtype=torch.long)
|
||||||
|
|
||||||
buffer["input_ids"].append(input_ids_with_concat)
|
buffer["input_ids"].append(input_ids_with_concat)
|
||||||
|
|||||||
@@ -42,7 +42,9 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||||
# TODO this could be sped up using numpy array slicing
|
# TODO this could be sped up using numpy array slicing
|
||||||
tokenized_full_prompt["labels"] = [-100] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
tokenized_full_prompt["labels"] = [
|
||||||
|
-100
|
||||||
|
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
||||||
|
|
||||||
return tokenized_full_prompt
|
return tokenized_full_prompt
|
||||||
|
|
||||||
|
|||||||
@@ -20,13 +20,9 @@ class AlpacaPrompter:
|
|||||||
# returns the full prompt from instruction and optional input
|
# returns the full prompt from instruction and optional input
|
||||||
# if a label (=response, =output) is provided, it's also appended.
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
if input:
|
if input:
|
||||||
res = self.prompt_input.format(
|
res = self.prompt_input.format(instruction=instruction, input=input)
|
||||||
instruction=instruction, input=input
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
res = self.prompt_no_input.format(
|
res = self.prompt_no_input.format(instruction=instruction)
|
||||||
instruction=instruction
|
|
||||||
)
|
|
||||||
if output:
|
if output:
|
||||||
res = f"{res}{output}"
|
res = f"{res}{output}"
|
||||||
return res
|
return res
|
||||||
@@ -41,6 +37,7 @@ class GPTeacherPrompter(AlpacaPrompter):
|
|||||||
|
|
||||||
class SeparatorStyle(Enum):
|
class SeparatorStyle(Enum):
|
||||||
"""Different separator style."""
|
"""Different separator style."""
|
||||||
|
|
||||||
SINGLE = auto()
|
SINGLE = auto()
|
||||||
TWO = auto()
|
TWO = auto()
|
||||||
DOLLY = auto()
|
DOLLY = auto()
|
||||||
@@ -50,6 +47,7 @@ class SeparatorStyle(Enum):
|
|||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Conversation:
|
class Conversation:
|
||||||
"""A class that keeps all conversation history."""
|
"""A class that keeps all conversation history."""
|
||||||
|
|
||||||
system: str
|
system: str
|
||||||
roles: List[str]
|
roles: List[str]
|
||||||
messages: List[List[str]]
|
messages: List[List[str]]
|
||||||
@@ -85,7 +83,7 @@ class Conversation:
|
|||||||
|
|
||||||
conv_vicuna_v1_1 = Conversation(
|
conv_vicuna_v1_1 = Conversation(
|
||||||
system="A chat between a curious user and an artificial intelligence assistant. "
|
system="A chat between a curious user and an artificial intelligence assistant. "
|
||||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||||
roles=["USER", "ASSISTANT"],
|
roles=["USER", "ASSISTANT"],
|
||||||
messages=[],
|
messages=[],
|
||||||
offset=0,
|
offset=0,
|
||||||
@@ -96,11 +94,7 @@ conv_vicuna_v1_1 = Conversation(
|
|||||||
|
|
||||||
|
|
||||||
class ShareGPTPrompter:
|
class ShareGPTPrompter:
|
||||||
def build_prompt(
|
def build_prompt(self, source, tokenizer):
|
||||||
self,
|
|
||||||
source,
|
|
||||||
tokenizer
|
|
||||||
):
|
|
||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
# If there isn't a back and forth conversation, ignore it
|
# If there isn't a back and forth conversation, ignore it
|
||||||
# also happens on the data splitting leaving empty conversations
|
# also happens on the data splitting leaving empty conversations
|
||||||
@@ -111,7 +105,10 @@ class ShareGPTPrompter:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Apply prompt templates
|
# Apply prompt templates
|
||||||
if source[0]["from"] not in roles or roles[source[0]["from"]] != conv.roles[0]:
|
if (
|
||||||
|
source[0]["from"] not in roles
|
||||||
|
or roles[source[0]["from"]] != conv.roles[0]
|
||||||
|
):
|
||||||
# Skip the first one if it is not from human
|
# Skip the first one if it is not from human
|
||||||
source = source[1:]
|
source = source[1:]
|
||||||
except IndexError as e:
|
except IndexError as e:
|
||||||
@@ -150,11 +147,19 @@ class ShareGPTPrompter:
|
|||||||
parts[0] += sep
|
parts[0] += sep
|
||||||
round_len = len(tokenizer(rou)["input_ids"])
|
round_len = len(tokenizer(rou)["input_ids"])
|
||||||
instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
|
instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
|
||||||
target[cur_len:cur_len+instruction_len] = [IGNORE_TOKEN_ID] * instruction_len
|
target[cur_len : cur_len + instruction_len] = [
|
||||||
|
IGNORE_TOKEN_ID
|
||||||
|
] * instruction_len
|
||||||
|
|
||||||
cur_len += round_len
|
cur_len += round_len
|
||||||
target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
|
target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
|
||||||
attention_mask = [1 if x != tokenizer.pad_token_id else 0 for x in tokenized_result["input_ids"]]
|
attention_mask = [
|
||||||
|
1 if x != tokenizer.pad_token_id else 0
|
||||||
|
for x in tokenized_result["input_ids"]
|
||||||
|
]
|
||||||
|
|
||||||
return dict(input_ids=tokenized_result["input_ids"], labels=target,
|
return dict(
|
||||||
attention_mask=attention_mask)
|
input_ids=tokenized_result["input_ids"],
|
||||||
|
labels=target,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user