CI for FlashAttention 3 (FA3)
This commit is contained in:
@@ -458,6 +458,7 @@ def cleanup_monkeypatches():
|
||||
("transformers.trainer",),
|
||||
("transformers", ["Trainer"]),
|
||||
("transformers.loss.loss_utils",),
|
||||
("transformers.modeling_flash_attention_utils",),
|
||||
]
|
||||
for module_name_tuple in modules_to_reset:
|
||||
module_name = module_name_tuple[0]
|
||||
|
||||
@@ -4,8 +4,8 @@ E2E tests for packed training
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
@@ -14,19 +14,22 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_tensorboard, with_temp_dir
|
||||
from .utils import check_tensorboard
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestPackedLlama(unittest.TestCase):
|
||||
class TestPackedLlama:
|
||||
"""
|
||||
Test case for Packed training of llama models
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_loss_packed(self, temp_dir):
|
||||
@pytest.mark.parametrize(
|
||||
"use_flash_attention_3",
|
||||
[False, "auto"],
|
||||
)
|
||||
def test_loss_packed(self, temp_dir, use_flash_attention_3):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -54,6 +57,7 @@ class TestPackedLlama(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"use_tensorboard": True,
|
||||
"use_flash_attention_3": use_flash_attention_3,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
|
||||
Reference in New Issue
Block a user