llama test
This commit is contained in:
@@ -387,13 +387,12 @@ class ModelLoader:
|
||||
self.patch_attention()
|
||||
|
||||
if self.cfg.model_config_type == "llama":
|
||||
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||
patch_flash_attention_forward,
|
||||
from axolotl.monkeypatch.trainer_grad_accum import ( # patch_flash_attention_forward,
|
||||
patch_forward_for_ga,
|
||||
patch_training_step_for_ga,
|
||||
)
|
||||
|
||||
patch_flash_attention_forward()
|
||||
# patch_flash_attention_forward()
|
||||
patch_forward_for_ga()
|
||||
patch_training_step_for_ga()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user