support for mamba (#915)
* support for mamba
* more mamba fixes
* use fork for mamba kwargs fix
* grad checkpointing doesn't work
* fix extras for mamba
* mamba loss fix
* use fp32 and remove verbose logging
* mamba fixes
* fix collator for mamba
* set model_type on training_args
* don't save safetensors for mamba
* update mamba config to disable safetensor checkpoints, install for tests
* no evals for mamba tests
* handle save_pretrained
* handle unused safetensors arg
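The last two items ("handle save_pretrained", "handle unused safetensors arg") can be sketched roughly as follows: the upstream Trainer forwards a `safe_serialization` kwarg that the vendored Mamba model's `save_pretrained` does not accept, so a thin wrapper can drop it before delegating. This is a minimal illustrative sketch, not the PR's actual code; the class and attribute names are assumptions.

```python
class MambaSaveWrapper:
    """Hypothetical sketch: swallow the safetensors kwarg that the
    wrapped Mamba model's save_pretrained does not understand."""

    def __init__(self, model):
        self.model = model

    def save_pretrained(self, save_directory, **kwargs):
        # Drop the safetensors flag the Trainer forwards; the wrapped
        # model serializes with plain torch.save and rejects it.
        kwargs.pop("safe_serialization", None)
        return self.model.save_pretrained(save_directory, **kwargs)
```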
.github/workflows/tests.yml (vendored): 2 changed lines
@@ -73,7 +73,7 @@ jobs:
         run: |
           pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
           pip3 uninstall -y transformers accelerate
-          pip3 install -U -e .[flash-attn]
+          pip3 install -U -e .[flash-attn,mamba-ssm]
           pip3 install -r requirements-tests.txt
 
       - name: Run e2e tests
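The "fix collator for mamba" item in the commit message can be sketched as below: Mamba state-space models take no `attention_mask`, so a Mamba-friendly collator only needs to pad `input_ids` and mirror them into `labels` (with the usual -100 ignore index on padding). The function and field names here are illustrative assumptions, not the PR's actual implementation.

```python
def mamba_collate(batch, pad_token_id=0):
    """Hypothetical sketch of a Mamba collator: pad to the longest
    sequence and emit no attention_mask, since Mamba ignores it."""
    max_len = max(len(example) for example in batch)
    # Pad inputs with the pad token; pad labels with -100 so padded
    # positions are excluded from the loss.
    input_ids = [ex + [pad_token_id] * (max_len - len(ex)) for ex in batch]
    labels = [ex + [-100] * (max_len - len(ex)) for ex in batch]
    return {"input_ids": input_ids, "labels": labels}
```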