Merge pull request #123 from OpenAccess-AI-Collective/bas-batch
add support for gradient accumulation steps
This commit is contained in:
@@ -149,8 +149,12 @@ def train(
|
|||||||
else:
|
else:
|
||||||
cfg[k] = kwargs[k]
|
cfg[k] = kwargs[k]
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
# setup some derived config / hyperparams
|
# setup some derived config / hyperparams
|
||||||
cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
|
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||||
|
cfg.batch_size // cfg.micro_batch_size
|
||||||
|
)
|
||||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
choose_device(cfg)
|
choose_device(cfg)
|
||||||
@@ -168,8 +172,6 @@ def train(
|
|||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
cfg.bf16 = False
|
cfg.bf16 = False
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
logging.info("loading tokenizer...")
|
logging.info("loading tokenizer...")
|
||||||
tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
|
tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
|
||||||
|
|||||||
@@ -4,6 +4,10 @@ import logging
|
|||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
|
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
"please set only one of gradient_accumulation_steps or batch_size"
|
||||||
|
)
|
||||||
if cfg.load_4bit:
|
if cfg.load_4bit:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
||||||
|
|||||||
@@ -117,3 +117,32 @@ class ValidationTest(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_gradient_accumulations_or_batch_size(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"batch_size": 1,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"batch_size": 1,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user