support for mamba (#915)

* support for mamba

* more mamba fixes

* use fork for mamba kwargs fix

* grad checkpointing doesn't work

* fix extras for mamba

* mamba loss fix

* use fp32 and remove verbose logging

* mamba fixes

* fix collator for mamba

* set model_type on training_args

* don't save safetensors for mamba

* update mamba config to disable safetensor checkpoints, install for tests

* no evals for mamba tests

* handle save_pretrained

* handle unused safetensors arg
This commit is contained in:
Wing Lian
2023-12-09 12:10:41 -05:00
committed by GitHub
parent d339beb9d9
commit 40a6362c92
12 changed files with 447 additions and 24 deletions

View File

@@ -82,7 +82,8 @@ def train(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
if hasattr(model, "config"):
model.config.use_cache = False
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
@@ -92,7 +93,8 @@ def train(
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
model.config.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0: