Fix merge conflict failure and apply black formatting

This commit is contained in:
Wing Lian
2023-05-25 22:40:27 -04:00
parent 3f6017db9e
commit 7b5e762be2
2 changed files with 5 additions and 7 deletions

View File

@@ -112,14 +112,10 @@ def load_tokenized_prepared_datasets(
raise Exception("unhandled dataset load") raise Exception("unhandled dataset load")
# support for using a subset of the data # support for using a subset of the data
if d.shards: if d.shards:
<<<<<<< Updated upstream
ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
=======
if "train" in ds: if "train" in ds:
ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0) ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
else: else:
ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0) ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
>>>>>>> Stashed changes
d_type = d.type d_type = d.type
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]

View File

@@ -247,8 +247,10 @@ def load_model(
model.resize_token_embeddings(embeddings_len) model.resize_token_embeddings(embeddings_len)
if ( if (
(cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora" ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit): and not cfg.load_4bit
and (load_in_8bit or cfg.load_in_4bit)
):
logging.info("converting PEFT model w/ prepare_model_for_int8_training") logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model) model = prepare_model_for_int8_training(model)
@@ -297,7 +299,7 @@ def load_adapter(model, cfg, adapter):
if adapter is None: if adapter is None:
return model, None return model, None
if adapter in ["lora" , "qlora"]: if adapter in ["lora", "qlora"]:
return load_lora(model, cfg) return load_lora(model, cfg)
if adapter == "llama-adapter": if adapter == "llama-adapter":
return load_llama_adapter(model, cfg) return load_llama_adapter(model, cfg)