Fix: Higher vram usage for mistral and sample_packing (#691)
* Fix: Higher vram usage for mistral and sample_packing * chore: update comment * chore: lint
This commit is contained in:
@@ -36,10 +36,10 @@ lora_target_modules:
|
|||||||
- k_proj
|
- k_proj
|
||||||
- o_proj
|
- o_proj
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
@@ -76,4 +76,4 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
|
|||||||
@@ -81,7 +81,8 @@ def load_tokenizer(cfg):
|
|||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
if cfg.is_mistral_derived_model:
|
# Mistral's official FA implementation requires left padding
|
||||||
|
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
|
|||||||
Reference in New Issue
Block a user