fix merge conflict failure, black format
This commit is contained in:
@@ -112,14 +112,10 @@ def load_tokenized_prepared_datasets(
|
||||
raise Exception("unhandled dataset load")
|
||||
# support for using a subset of the data
|
||||
if d.shards:
|
||||
<<<<<<< Updated upstream
|
||||
ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
|
||||
=======
|
||||
if "train" in ds:
|
||||
ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
|
||||
else:
|
||||
ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
|
||||
>>>>>>> Stashed changes
|
||||
d_type = d.type
|
||||
d_type_split = d_type.split(":")
|
||||
d_base_type = d_type_split[0]
|
||||
|
||||
@@ -247,8 +247,10 @@ def load_model(
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
|
||||
if (
|
||||
(cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
|
||||
) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit):
|
||||
((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
|
||||
and not cfg.load_4bit
|
||||
and (load_in_8bit or cfg.load_in_4bit)
|
||||
):
|
||||
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
||||
model = prepare_model_for_int8_training(model)
|
||||
|
||||
@@ -297,7 +299,7 @@ def load_adapter(model, cfg, adapter):
|
||||
|
||||
if adapter is None:
|
||||
return model, None
|
||||
if adapter in ["lora" , "qlora"]:
|
||||
if adapter in ["lora", "qlora"]:
|
||||
return load_lora(model, cfg)
|
||||
if adapter == "llama-adapter":
|
||||
return load_llama_adapter(model, cfg)
|
||||
|
||||
Reference in New Issue
Block a user