Fix merge-conflict failure (remove leftover conflict markers); apply black formatting

This commit is contained in:
Wing Lian
2023-05-25 22:40:27 -04:00
parent 3f6017db9e
commit 7b5e762be2
2 changed files with 5 additions and 7 deletions

View File

@@ -112,14 +112,10 @@ def load_tokenized_prepared_datasets(
raise Exception("unhandled dataset load")
# support for using a subset of the data
if d.shards:
<<<<<<< Updated upstream
ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
=======
if "train" in ds:
ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
else:
ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
>>>>>>> Stashed changes
d_type = d.type
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]

View File

@@ -247,8 +247,10 @@ def load_model(
model.resize_token_embeddings(embeddings_len)
if (
(cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit):
((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
and not cfg.load_4bit
and (load_in_8bit or cfg.load_in_4bit)
):
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model)
@@ -297,7 +299,7 @@ def load_adapter(model, cfg, adapter):
if adapter is None:
return model, None
if adapter in ["lora" , "qlora"]:
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg)
if adapter == "llama-adapter":
return load_llama_adapter(model, cfg)