fix tokenizer loading, got openllama 3b working
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
base_model: huggyllama/llama-7b
|
base_model: openlm-research/open_llama_3b_600bt_preview
|
||||||
base_model_config: huggyllama/llama-7b
|
base_model_config: openlm-research/open_llama_3b_600bt_preview
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
load_in_8bit: true
|
load_in_8bit: true
|
||||||
@@ -32,9 +32,9 @@ wandb_watch:
|
|||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
batch_size: 4
|
batch_size: 16
|
||||||
micro_batch_size: 1
|
micro_batch_size: 4
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
@@ -211,12 +211,12 @@ def load_model(
|
|||||||
try:
|
try:
|
||||||
if is_llama_derived_model and "LlamaTokenizer" in globals():
|
if is_llama_derived_model and "LlamaTokenizer" in globals():
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(
|
tokenizer = LlamaTokenizer.from_pretrained(
|
||||||
model,
|
base_model_config,
|
||||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||||
model,
|
base_model_config,
|
||||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
|
|||||||
Reference in New Issue
Block a user