fix relative path for fixtures
This commit is contained in:
@@ -129,6 +129,7 @@ def load_model(
|
|||||||
llm_int8_threshold=6.0,
|
llm_int8_threshold=6.0,
|
||||||
llm_int8_has_fp16_weight=False,
|
llm_int8_has_fp16_weight=False,
|
||||||
bnb_4bit_compute_dtype=torch_dtype,
|
bnb_4bit_compute_dtype=torch_dtype,
|
||||||
|
bnb_4bit_compute_dtype=torch_dtype,
|
||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4",
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
@@ -280,8 +281,8 @@ def load_model(
|
|||||||
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
||||||
# so let's only set it for the 4bit, see
|
# so let's only set it for the 4bit, see
|
||||||
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
||||||
setattr(model, 'is_parallelizable', True)
|
setattr(model, "is_parallelizable", True)
|
||||||
setattr(model, 'model_parallel', True)
|
setattr(model, "model_parallel", True)
|
||||||
|
|
||||||
requires_grad = []
|
requires_grad = []
|
||||||
for name, param in model.named_parameters(recurse=True):
|
for name, param in model.named_parameters(recurse=True):
|
||||||
|
|||||||
@@ -125,7 +125,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
output_dir=cfg.output_dir,
|
output_dir=cfg.output_dir,
|
||||||
save_total_limit=3,
|
save_total_limit=3,
|
||||||
load_best_model_at_end=(
|
load_best_model_at_end=(
|
||||||
cfg.val_set_size > 0
|
cfg.load_best_model_at_end is not False
|
||||||
|
and cfg.val_set_size > 0
|
||||||
and save_steps
|
and save_steps
|
||||||
and save_steps % eval_steps == 0
|
and save_steps % eval_steps == 0
|
||||||
and cfg.load_in_8bit is not True
|
and cfg.load_in_8bit is not True
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
|
"""Module for testing prompt tokenizers."""
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
@@ -12,6 +14,10 @@ logging.basicConfig(level="INFO")
|
|||||||
|
|
||||||
|
|
||||||
class TestPromptTokenizationStrategies(unittest.TestCase):
|
class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test class for prompt tokenization strategies.
|
||||||
|
"""
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
self.tokenizer.add_special_tokens(
|
self.tokenizer.add_special_tokens(
|
||||||
@@ -24,10 +30,15 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
|
|
||||||
def test_sharegpt_integration(self):
|
def test_sharegpt_integration(self):
|
||||||
print(Path(__file__).parent)
|
print(Path(__file__).parent)
|
||||||
with open(Path(__file__).parent / "fixtures/conversation.json", "r") as fin:
|
with open(
|
||||||
|
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||||
|
) as fin:
|
||||||
data = fin.read()
|
data = fin.read()
|
||||||
conversation = json.loads(data)
|
conversation = json.loads(data)
|
||||||
with open(Path(__file__).parent / "fixtures/conversation.tokenized.json", "r") as fin:
|
with open(
|
||||||
|
Path(__file__).parent / "fixtures/conversation.tokenized.json",
|
||||||
|
encoding="utf-8",
|
||||||
|
) as fin:
|
||||||
data = fin.read()
|
data = fin.read()
|
||||||
tokenized_conversation = json.loads(data)
|
tokenized_conversation = json.loads(data)
|
||||||
prompter = ShareGPTPrompter("chat")
|
prompter = ShareGPTPrompter("chat")
|
||||||
|
|||||||
Reference in New Issue
Block a user