Update doc for grad_accu and add validation tests for batch size
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
"""Module for testing the validation module"""
|
||||
|
||||
import logging
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -13,6 +15,12 @@ class ValidationTest(unittest.TestCase):
|
||||
Test the validation module
|
||||
"""
|
||||
|
||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
|
||||
def test_load_4bit_deprecate(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -23,6 +31,17 @@ class ValidationTest(unittest.TestCase):
|
||||
with pytest.raises(ValueError):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_batch_size_unused_warning(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"batch_size": 32,
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert "batch_size is not recommended" in self._caplog.records[0].message
|
||||
|
||||
def test_qlora(self):
|
||||
base_cfg = DictDefault(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user