From 05b398a0726563850700874e3390663e6751c5ff Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 29 Mar 2024 02:38:02 -0400 Subject: [PATCH] fix some of the edge cases for Jamba (#1452) * fix some of the edge cases for Jamba * update requirements for jamba --- .github/workflows/pypi.yml | 2 +- .github/workflows/tests.yml | 2 + examples/jamba/README.md | 11 +++-- examples/jamba/qlora.yaml | 62 ++++++++++++++++++++++++++++ requirements.txt | 2 +- setup.py | 2 +- src/axolotl/monkeypatch/multipack.py | 24 ++++++----- src/axolotl/utils/models.py | 4 ++ 8 files changed, 92 insertions(+), 17 deletions(-) create mode 100644 examples/jamba/qlora.yaml diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index dbd225f6f..885239d18 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | - pip3 install wheel + pip3 install wheel packaging pip3 install -e . pip3 install -r requirements-tests.txt diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d8ca5400a..91022c2e1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,6 +48,8 @@ jobs: - name: Install dependencies run: | + pip3 install --upgrade pip + pip3 install --upgrade packaging pip3 install -U -e . pip3 install -r requirements-tests.txt diff --git a/examples/jamba/README.md b/examples/jamba/README.md index aa98c0245..54f5d1da9 100644 --- a/examples/jamba/README.md +++ b/examples/jamba/README.md @@ -1,5 +1,10 @@ # Jamba -qlora w/ deepspeed needs at least 2x GPUs and 35GiB VRAM per GPU - -qlora single-gpu - training will start, but loss is off by an order of magnitude +- ✅ qlora w/ deepspeed Zero-2 needs at least 2x GPUs and + - 35GiB VRAM per GPU w minimal context length + - 56GiB VRAM per GPU (w multipack enabled) +- ✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?) +- ✅ qlora single-gpu, ~51GiB VRAM +- ✅ multipack +- ❓ FSDP +- ❓ 8-bit LoRA diff --git a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml new file mode 100644 index 000000000..41a3854fe --- /dev/null +++ b/examples/jamba/qlora.yaml @@ -0,0 +1,62 @@ +base_model: ai21labs/Jamba-v0.1 +trust_remote_code: true + +load_in_8bit: false +load_in_4bit: true +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: +val_set_size: 0.0 +output_dir: ./out + +sequence_len: 4096 +sample_packing: false +pad_to_sequence_len: false +eval_sample_packing: false + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +adapter: qlora +lora_r: 8 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true + +low_cpu_mem_usage: true +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 2 +optimizer: paged_adamw_8bit +lr_scheduler: cosine +learning_rate: 0.00001 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +special_tokens: diff --git a/requirements.txt b/requirements.txt index 75ce7a0d8..b3db07d05 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,7 +32,7 @@ fschat==0.2.36 gradio==3.50.2 tensorboard -mamba-ssm==1.1.1 +mamba-ssm==1.2.0.post1 # remote filesystems s3fs diff --git a/setup.py b/setup.py index 307691bd4..fbca5a360 100644 --- a/setup.py +++ b/setup.py @@ -78,7 +78,7 @@ setup( "deepspeed-kernels", ], "mamba-ssm": [ - "mamba-ssm==1.0.1", + "mamba-ssm==1.2.0.post1", ], "auto-gptq": [ "auto-gptq==0.5.1", diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index fbcaf7a66..a8f5e7a84 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -48,14 +48,16 @@ def patch_for_multipack(model_type, model_name=None): get_unpad_data ) elif model_type == "gemmoe": - model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - # we need to load the model here in order for modeling_gemmoe to be available - with init_empty_weights(): - AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) - module_name = model_config.__class__.__module__.replace( - ".configuration_gemmoe", ".modeling_gemmoe" - ) - modeling_gemmoe = importlib.import_module(module_name) - modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) + patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") + elif model_type == "jamba": + patch_remote(model_name, ".configuration_jamba", ".modeling_jamba") + + +def patch_remote(model_name, config_name, modeling_name): + model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + # we need to load the model here in order for modeling_* to be available + with init_empty_weights(): + AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + module_name = model_config.__class__.__module__.replace(config_name, modeling_name) + modeling_arch = importlib.import_module(module_name) + modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 911a6c31b..31686f600 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -456,6 +456,10 @@ def load_model( "bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_storage": torch.bfloat16, } + if cfg.model_config_type == "jamba" and not cfg.deepspeed: + # for some reason, this causes the loss to be off by an order of magnitude + # but deepspeed needs this still in bfloat16 + bnb_config["bnb_4bit_quant_storage"] = torch.float32 if cfg.bnb_config_kwargs: bnb_config.update(cfg.bnb_config_kwargs)