From 1e4366070179f238a40b7ef8356be744350c1a38 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 27 Aug 2024 13:39:24 -0400
Subject: [PATCH] Sample pack trust remote code v2 (#1873)

* fix the multipack patch for remote code models

* add deepseek v2 lite example w fsdp
---
 examples/deepseek-v2/fft-fsdp-16b.yaml | 67 ++++++++++++++++++++++++++
 src/axolotl/monkeypatch/multipack.py   |  2 +
 src/axolotl/monkeypatch/utils.py       |  2 -
 3 files changed, 69 insertions(+), 2 deletions(-)
 create mode 100644 examples/deepseek-v2/fft-fsdp-16b.yaml

diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml
new file mode 100644
index 000000000..b55646df7
--- /dev/null
+++ b/examples/deepseek-v2/fft-fsdp-16b.yaml
@@ -0,0 +1,67 @@
+base_model: deepseek-ai/DeepSeek-V2-Lite
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 8
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 100
+evals_per_epoch: 2
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+special_tokens:
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: true
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_sharding_strategy: FULL_SHARD
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 44fc4cb47..529c42a8f 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -94,3 +94,5 @@ def patch_remote(model_name, config_name, modeling_name):
     module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
     modeling_arch = importlib.import_module(module_name)
     modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
+    # workaround to make the patch stick
+    modeling_arch._axolotl_multipack_patch = True  # pylint: disable=protected-access
diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index e43c58650..f29f21be7 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -17,11 +17,9 @@ def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
     max_num = int(torch.max(attention_mask).item())
     batch_size, _ = attention_mask.shape
     counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
-
     for i in range(1, max_num + 1):
         mask = attention_mask == i
         counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
-
     result = counts.flatten()
     nonzero_indices = torch.nonzero(result).squeeze(-1)
     return result[nonzero_indices]
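
Notes on the new example config: it full-fine-tunes the 16B DeepSeek-V2-Lite MoE
(a trust_remote_code model) with sample packing and full-shard FSDP, using CPU
offload and transformer-based auto-wrap on DeepseekV2DecoderLayer. It should run
like any other axolotl example on a multi-GPU accelerate setup, e.g.
`accelerate launch -m axolotl.cli.train examples/deepseek-v2/fft-fsdp-16b.yaml`.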
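
On the multipack.py hunk: a trust_remote_code model's modeling code lives in a
dynamically imported module rather than in transformers proper, so patch_remote
derives that module's import path from the config class's __module__ and
overwrites its private _get_unpad_data with axolotl's sample-packing-aware
version; the added marker attribute records that the module has been patched.
A minimal sketch of that attribute swap, with an adapted signature that takes
the already-loaded config object (only the tail of the real function is visible
in this hunk, so the surrounding names here are assumptions):

    import importlib

    def patch_remote_module(model_config, config_name: str, modeling_name: str, get_unpad_data) -> None:
        # e.g. "...configuration_deepseek" -> "...modeling_deepseek"
        module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
        modeling_arch = importlib.import_module(module_name)
        # swap in the packing-aware helper and mark the module as patched
        modeling_arch._get_unpad_data = get_unpad_data
        modeling_arch._axolotl_multipack_patch = True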
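
On the utils.py hunk: with sample packing, each position in attention_mask
carries the 1-based index of the packed sequence that token belongs to (0 marks
padding), so get_max_seqlen_in_batch yields the per-sequence token counts the
flash attention unpadding path consumes. A runnable illustration using the
function exactly as it reads after this patch; the toy mask is made up for the
example:

    import torch

    def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
        # highest packed-sequence index present anywhere in the batch
        max_num = int(torch.max(attention_mask).item())
        batch_size, _ = attention_mask.shape
        counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
        for i in range(1, max_num + 1):
            # tokens belonging to packed sequence i, counted per row
            mask = attention_mask == i
            counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
        result = counts.flatten()
        nonzero_indices = torch.nonzero(result).squeeze(-1)
        return result[nonzero_indices]

    # row 0 packs sequences of length 3 and 2 (trailing 0 is padding);
    # row 1 holds a single sequence of length 4
    mask = torch.tensor([[1, 1, 1, 2, 2, 0],
                         [1, 1, 1, 1, 0, 0]])
    print(get_max_seqlen_in_batch(mask))  # tensor([3, 2, 4], dtype=torch.int32)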