refactor to set eval_batch_size earlier if unset, so we can warn if mismatched (#662)
This commit is contained in:
@@ -571,7 +571,7 @@ torch_compile_backend: # Optional[str]
|
|||||||
# training hyperparameters
|
# training hyperparameters
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
eval_batch_size: 2
|
eval_batch_size:
|
||||||
num_epochs: 3
|
num_epochs: 3
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
learning_rate: 0.00003
|
learning_rate: 0.00003
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ def normalize_config(cfg):
|
|||||||
cfg.batch_size = (
|
cfg.batch_size = (
|
||||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||||
)
|
)
|
||||||
|
if cfg.eval_batch_size is None:
|
||||||
|
cfg.eval_batch_size = cfg.micro_batch_size
|
||||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
cfg.eval_table_size = cfg.eval_table_size or 0
|
cfg.eval_table_size = cfg.eval_table_size or 0
|
||||||
@@ -157,6 +159,11 @@ def validate_config(cfg):
|
|||||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||||
)
|
)
|
||||||
|
if cfg.eval_batch_size != cfg.micro_batch_size:
|
||||||
|
LOG.warning(
|
||||||
|
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.load_4bit:
|
if cfg.load_4bit:
|
||||||
raise ValueError("cfg.load_4bit parameter has been deprecated")
|
raise ValueError("cfg.load_4bit parameter has been deprecated")
|
||||||
|
|
||||||
|
|||||||
@@ -668,9 +668,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||||
max_seq_length=cfg.sequence_len,
|
max_seq_length=cfg.sequence_len,
|
||||||
per_device_train_batch_size=cfg.micro_batch_size,
|
per_device_train_batch_size=cfg.micro_batch_size,
|
||||||
per_device_eval_batch_size=cfg.eval_batch_size
|
per_device_eval_batch_size=cfg.eval_batch_size,
|
||||||
if cfg.eval_batch_size is not None
|
|
||||||
else cfg.micro_batch_size,
|
|
||||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
num_train_epochs=cfg.num_epochs,
|
num_train_epochs=cfg.num_epochs,
|
||||||
|
|||||||
Reference in New Issue
Block a user