Fix: Higher vram usage for mistral and sample_packing (#691)

* Fix: Higher vram usage for mistral and sample_packing

* chore: update comment

* chore: lint
This commit is contained in:
NanoCode012
2023-10-07 01:33:43 +09:00
committed by GitHub
parent d4a88e4eca
commit 669f1d052c
2 changed files with 6 additions and 5 deletions

View File

@@ -36,10 +36,10 @@ lora_target_modules:
- k_proj - k_proj
- o_proj - o_proj
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -76,4 +76,4 @@ fsdp_config:
special_tokens: special_tokens:
bos_token: "<s>" bos_token: "<s>"
eos_token: "</s>" eos_token: "</s>"
unk_token: "<unk>" unk_token: "<unk>"

View File

@@ -81,7 +81,8 @@ def load_tokenizer(cfg):
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
- if cfg.is_mistral_derived_model:
+ # Mistral's official FA implementation requires left padding
+ if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left" tokenizer.padding_side = "left"
if cfg.special_tokens: if cfg.special_tokens: