Merge pull request #123 from OpenAccess-AI-Collective/bas-batch

add support for gradient accumulation steps
Author: Wing Lian
Date: 2023-05-30 23:45:29 -04:00
Committed by: GitHub
3 changed files with 38 additions and 3 deletions

@@ -149,8 +149,12 @@ def train(
         else:
             cfg[k] = kwargs[k]
 
+    validate_config(cfg)
+
     # setup some derived config / hyperparams
-    cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
+    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
+        cfg.batch_size // cfg.micro_batch_size
+    )
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     choose_device(cfg)
@@ -168,8 +172,6 @@ def train(
         cfg.fp16 = True
         cfg.bf16 = False
 
-    validate_config(cfg)
-
     # load the tokenizer first
     logging.info("loading tokenizer...")
     tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)

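The net effect in train() above: gradient_accumulation_steps is now only derived from batch_size // micro_batch_size when it isn't set explicitly, and validate_config() runs before the derivation so the mutual-exclusivity check (next file) sees the raw user input. A minimal sketch of the fallback; derived_grad_accum is a hypothetical helper, not part of the repo:

def derived_grad_accum(cfg: dict) -> int:
    # mirror the fallback in the diff: prefer an explicit setting,
    # otherwise derive it from batch_size // micro_batch_size
    return cfg.get("gradient_accumulation_steps") or (
        cfg["batch_size"] // cfg["micro_batch_size"]
    )

# e.g. batch_size=64 with micro_batch_size=4 gives 64 // 4 = 16 accumulation steps
assert derived_grad_accum({"batch_size": 64, "micro_batch_size": 4}) == 16
# an explicit setting now wins instead of being silently overwritten
assert derived_grad_accum({"gradient_accumulation_steps": 8}) == 8

Per optimizer step the effective global batch is micro_batch_size * gradient_accumulation_steps * world_size (the standard data-parallel accounting); this diff only changes how the middle factor is chosen.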
@@ -4,6 +4,10 @@ import logging
 
 
 def validate_config(cfg):
+    if cfg.gradient_accumulation_steps and cfg.batch_size:
+        raise ValueError(
+            "please set only one of gradient_accumulation_steps or batch_size"
+        )
     if cfg.load_4bit:
         raise ValueError(
             "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"

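The new check makes the two knobs mutually exclusive, since either one determines the other given micro_batch_size. A runnable sketch of the behavior, trimmed to just this check; FallbackConfig is a hypothetical stand-in for the project's DictDefault (a dict whose missing keys read as None):

class FallbackConfig(dict):
    # attribute access falls back to .get(), so missing keys read as None
    def __getattr__(self, name):
        return self.get(name)

def validate_config(cfg):
    # only the check added in this diff
    if cfg.gradient_accumulation_steps and cfg.batch_size:
        raise ValueError(
            "please set only one of gradient_accumulation_steps or batch_size"
        )

validate_config(FallbackConfig(batch_size=1))                   # ok
validate_config(FallbackConfig(gradient_accumulation_steps=1))  # ok
try:
    validate_config(FallbackConfig(batch_size=1, gradient_accumulation_steps=1))
except ValueError as err:
    print(err)  # please set only one of gradient_accumulation_steps or batch_size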
@@ -117,3 +117,32 @@ class ValidationTest(unittest.TestCase):
             }
         )
         validate_config(cfg)
+
+    def test_gradient_accumulations_or_batch_size(self):
+        cfg = DictDefault(
+            {
+                "gradient_accumulation_steps": 1,
+                "batch_size": 1,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "batch_size": 1,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "gradient_accumulation_steps": 1,
+            }
+        )
+
+        validate_config(cfg)
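The test covers the failing combination plus both single-knob configs. Assuming the file sits in the repo's tests directory under a name like tests/test_validation.py (the path is a guess), the new case can be run in isolation with pytest's -k filter:

pytest tests/test_validation.py -k test_gradient_accumulations_or_batch_size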