add support for trust_remote_code for mpt models

This commit is contained in:
Wing Lian
2023-05-08 12:07:27 -04:00
parent 709be5af81
commit a125693122
3 changed files with 68 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
# MPT-7B
```shell
accelerate launch scripts/finetune.py examples/mpt-7b/config.yml
```

View File

@@ -0,0 +1,59 @@
# Example fine-tuning configuration for MPT-7B on the alpaca-gpt4 dataset.
# Keys left empty are parsed as null (i.e. the option is unset).
base_model: mosaicml/mpt-7b
base_model_config: mosaicml/mpt-7b
model_type: AutoModelForCausalLM
tokenizer_type: GPTNeoXTokenizer
# Required for MPT: its model class is not merged into transformers yet, so it
# is loaded from code shipped in the model repository. The loader only enables
# this when the value is exactly `true`.
trust_remote_code: true
# The loader applies 8-bit loading only when an adapter is also configured.
load_in_8bit: false
datasets:
  - path: vicgalle/alpaca-gpt4
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
# No adapter is set, so this config describes a full fine-tune.
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# NOTE(review): only relevant when `adapter` is set; confirm these module names
# exist in the MPT architecture before enabling LoRA.
lora_target_modules:
  - q_proj
  - v_proj
lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./mpt-alpaca-7b
batch_size: 4
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
# NOTE(review): 2e-7 is unusually small for fine-tuning; confirm it is intended.
learning_rate: 0.0000002
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 5
xformers_attention:
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 110
save_steps: 660
debug:
deepspeed:
weight_decay: 0.0001
fsdp:
fsdp_config:
# Token overrides must be nested under `special_tokens` to be read as a mapping.
special_tokens:
  pad_token: "<|padding|>"
  bos_token: "<|endoftext|>"
  eos_token: "<|endoftext|>"
  unk_token: "<|endoftext|>"

View File

@@ -113,6 +113,7 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
else:
model = AutoModelForCausalLM.from_pretrained(
@@ -120,6 +121,7 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
except Exception as e:
logging.error(
@@ -131,6 +133,7 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
if not tokenizer: