Feat: Add Qwen (#894)

* Feat: Add Qwen

* feat: add qwen lora example

* feat: update matrix

* fix: add trust_remote_code

* fix: disable gradient checkpointing

* chore: add warning about gradient checkpointing

* fix: config

* fix: turn off sample packing for this example and reduce seq len

* chore: add comment on seq len
This commit is contained in:
NanoCode012
2023-11-26 00:05:01 +09:00
committed by GitHub
parent 7ee3c4cacb
commit 1115c501b8
5 changed files with 168 additions and 0 deletions

View File

@@ -122,6 +122,19 @@ def normalize_config(cfg):
or (cfg.model_type and "mistral" in cfg.model_type.lower())
)
cfg.is_qwen_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"qwen",
]
)
or cfg.is_qwen_derived_model
or "qwen" in cfg.base_model.lower()
or (cfg.model_type and "qwen" in cfg.model_type.lower())
)
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)
@@ -379,6 +392,11 @@ def validate_config(cfg):
if cfg.warmup_steps and cfg.warmup_ratio:
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
if cfg.is_qwen_derived_model and cfg.gradient_checkpointing:
LOG.warning(
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -84,6 +84,18 @@ def load_tokenizer(cfg):
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left"
# Qwen base only has single token, so we need to set the special tokens
if cfg.is_qwen_derived_model:
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
for attr_name in token_ids:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, tokenizer.eod_id)
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
for attr_name in token_names:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
if cfg.special_tokens:
for k, val in cfg.special_tokens.items():
tokenizer.add_special_tokens(