Fix merge conflict failure (remove leftover conflict markers) and apply black formatting
This commit is contained in:
@@ -112,14 +112,10 @@ def load_tokenized_prepared_datasets(
|
|||||||
raise Exception("unhandled dataset load")
|
raise Exception("unhandled dataset load")
|
||||||
# support for using a subset of the data
|
# support for using a subset of the data
|
||||||
if d.shards:
|
if d.shards:
|
||||||
<<<<<<< Updated upstream
|
|
||||||
ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
|
|
||||||
=======
|
|
||||||
if "train" in ds:
|
if "train" in ds:
|
||||||
ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
|
ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
|
||||||
else:
|
else:
|
||||||
ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
|
ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
|
||||||
>>>>>>> Stashed changes
|
|
||||||
d_type = d.type
|
d_type = d.type
|
||||||
d_type_split = d_type.split(":")
|
d_type_split = d_type.split(":")
|
||||||
d_base_type = d_type_split[0]
|
d_base_type = d_type_split[0]
|
||||||
|
|||||||
@@ -247,8 +247,10 @@ def load_model(
|
|||||||
model.resize_token_embeddings(embeddings_len)
|
model.resize_token_embeddings(embeddings_len)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
|
((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
|
||||||
) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit):
|
and not cfg.load_4bit
|
||||||
|
and (load_in_8bit or cfg.load_in_4bit)
|
||||||
|
):
|
||||||
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
||||||
model = prepare_model_for_int8_training(model)
|
model = prepare_model_for_int8_training(model)
|
||||||
|
|
||||||
@@ -297,7 +299,7 @@ def load_adapter(model, cfg, adapter):
|
|||||||
|
|
||||||
if adapter is None:
|
if adapter is None:
|
||||||
return model, None
|
return model, None
|
||||||
if adapter in ["lora" , "qlora"]:
|
if adapter in ["lora", "qlora"]:
|
||||||
return load_lora(model, cfg)
|
return load_lora(model, cfg)
|
||||||
if adapter == "llama-adapter":
|
if adapter == "llama-adapter":
|
||||||
return load_llama_adapter(model, cfg)
|
return load_llama_adapter(model, cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user