Update docs for gradient_accumulation_steps and add validation tests for batch_size
This commit is contained in:
@@ -397,6 +397,7 @@ Add below flag to train command above
|
|||||||
Please reduce any of the below
|
Please reduce any of the below
|
||||||
- `micro_batch_size`
|
- `micro_batch_size`
|
||||||
- `eval_batch_size`
|
- `eval_batch_size`
|
||||||
|
- `gradient_accumulation_steps`
|
||||||
- `sequence_len`
|
- `sequence_len`
|
||||||
|
|
||||||
> RuntimeError: expected scalar type Float but found Half
|
> RuntimeError: expected scalar type Float but found Half
|
||||||
|
|||||||
@@ -8,6 +8,12 @@ def validate_config(cfg):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"please set only one of gradient_accumulation_steps or batch_size"
|
"please set only one of gradient_accumulation_steps or batch_size"
|
||||||
)
|
)
|
||||||
|
if cfg.batch_size:
|
||||||
|
logging.warning(
|
||||||
|
"%s\n%s",
|
||||||
|
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||||
|
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||||
|
)
|
||||||
if cfg.load_4bit:
|
if cfg.load_4bit:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Module for testing the validation module"""
|
"""Module for testing the validation module"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -13,6 +15,12 @@ class ValidationTest(unittest.TestCase):
|
|||||||
Test the validation module
|
Test the validation module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def inject_fixtures(self, caplog):
|
||||||
|
self._caplog = caplog
|
||||||
|
|
||||||
def test_load_4bit_deprecate(self):
|
def test_load_4bit_deprecate(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -23,6 +31,17 @@ class ValidationTest(unittest.TestCase):
|
|||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_batch_size_unused_warning(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"batch_size": 32,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert "batch_size is not recommended" in self._caplog.records[0].message
|
||||||
|
|
||||||
def test_qlora(self):
|
def test_qlora(self):
|
||||||
base_cfg = DictDefault(
|
base_cfg = DictDefault(
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user