support for mamba (#915)
* support for mamba * more mamba fixes * use fork for mamba kwargs fix * grad checkpointing doesn't work * fix extras for mamaba * mamba loss fix * use fp32 and remove verbose logging * mamba fixes * fix collator for mamba * set model_type on training_args * don't save safetensors for mamba * update mamba config to disable safetensor checkpooints, install for tests * no evals for mamba tests * handle save_pretrained * handle unused safetensors arg
This commit is contained in:
@@ -82,7 +82,8 @@ def train(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||
)
|
||||
|
||||
model.config.use_cache = False
|
||||
if hasattr(model, "config"):
|
||||
model.config.use_cache = False
|
||||
|
||||
# go ahead and presave, so we have the adapter config available to inspect
|
||||
if peft_config:
|
||||
@@ -92,7 +93,8 @@ def train(
|
||||
if not Path(cfg.output_dir).is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||
if hasattr(model, "config"):
|
||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
Reference in New Issue
Block a user