support for auto_find_batch_size when packing (#1885)

* support for auto_find_batch_size when packing

* make sure to return data from validation

* make sure to return data from validation

* actually expose multipack_real_batches in the config

* calculate gathered efficiency in sampler

* tweak to fix auto find and use actual sampler len for multipack

* uncomment

* use args for bsz when not available from auto find
This commit is contained in:
Wing Lian
2024-09-03 20:02:44 -04:00
committed by GitHub
parent 0aeb277456
commit 4e5400c732
4 changed files with 50 additions and 10 deletions

View File

@@ -506,9 +506,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
batch_max_len = self.args.max_seq_length batch_max_len = self.args.max_seq_length
else: else:
batch_size = 1 batch_size = 1
batch_max_len = ( train_batch_size = (
self.args.per_device_train_batch_size * self.args.max_seq_length self.state.train_batch_size or self.args.per_device_train_batch_size
) )
batch_max_len = train_batch_size * self.args.max_seq_length
return MultipackBatchSampler( return MultipackBatchSampler(
RandomSampler(self.train_dataset), RandomSampler(self.train_dataset),
lengths=get_dataset_lengths(self.train_dataset), lengths=get_dataset_lengths(self.train_dataset),
@@ -1379,6 +1380,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[ training_arguments_kwargs[
"per_device_eval_batch_size" "per_device_eval_batch_size"
] = self.cfg.eval_batch_size ] = self.cfg.eval_batch_size
if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs[
"auto_find_batch_size"
] = self.cfg.auto_find_batch_size
training_arguments_kwargs[ training_arguments_kwargs[
"gradient_accumulation_steps" "gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps ] = self.cfg.gradient_accumulation_steps
@@ -1461,9 +1466,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
) )
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs[ training_arguments_kwargs["multipack_real_batches"] = (
"multipack_real_batches" not self.cfg.flash_attention or self.cfg.multipack_real_batches
] = not self.cfg.flash_attention )
training_arguments_kwargs["eval_sample_packing"] = bool( training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing self.cfg.eval_sample_packing
) )

View File

@@ -355,6 +355,8 @@ class HyperparametersConfig(BaseModel):
}, },
) )
auto_find_batch_size: Optional[bool] = None
train_on_inputs: Optional[bool] = False train_on_inputs: Optional[bool] = False
group_by_length: Optional[bool] = None group_by_length: Optional[bool] = None
@@ -592,6 +594,7 @@ class AxolotlInputConfig(
eval_sample_packing: Optional[bool] = None eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None
# for PoSE context length extension # for PoSE context length extension
use_pose: Optional[bool] = None use_pose: Optional[bool] = None

View File

@@ -11,6 +11,8 @@ import numba
import numpy as np import numpy as np
from torch.utils.data import BatchSampler, Sampler from torch.utils.data import BatchSampler, Sampler
from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger("axolotl.utils.samplers.multipack") LOG = logging.getLogger("axolotl.utils.samplers.multipack")
@@ -174,16 +176,46 @@ class MultipackBatchSampler(BatchSampler):
def efficiency(self): def efficiency(self):
return self.eff_total_used / self.eff_total_slots return self.eff_total_used / self.eff_total_slots
def gather_efficiency(self):
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
return math.floor(0.997 * max(estimates))
sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
calc_sample_packing_eff_est,
)
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
)
return sample_packing_eff_est
def gather_len_batches(self, num):
def calc_min_len(estimates: list[(int, float)]):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(0.998 * min(estimates))
min_len_batches = reduce_and_broadcast(
lambda: num,
calc_min_len,
)
return min_len_batches
def __len__(self): def __len__(self):
self.num_batches() len_batches = self.num_batches()
return self._len_est() return self.gather_len_batches(len_batches)
def _len_est(self): def _len_est(self):
efficiency = (
self.packing_efficiency_estimate
if self.packing_efficiency_estimate
else self.gather_efficiency()
)
world_size = int(os.getenv("WORLD_SIZE", "1")) world_size = int(os.getenv("WORLD_SIZE", "1"))
lengths_sum = np.sum(self.lengths) lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // world_size lengths_sum_per_device = lengths_sum // world_size
LOG.info( LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " f"packing_efficiency_estimate: {efficiency} "
f"total_num_tokens per device: {lengths_sum_per_device}" f"total_num_tokens per device: {lengths_sum_per_device}"
) )
@@ -195,7 +227,7 @@ class MultipackBatchSampler(BatchSampler):
* math.floor( * math.floor(
0.99 0.99
* lengths_sum_per_device * lengths_sum_per_device
/ self.packing_efficiency_estimate / efficiency
// (self.batch_max_len * self.batch_size) // (self.batch_max_len * self.batch_size)
) )
- 1 - 1

View File

@@ -357,7 +357,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
main_process_only=True, main_process_only=True,
) )
else: else:
if cfg.flash_attention: if cfg.flash_attention and not cfg.multipack_real_batches:
sampler_batch_size = 1 sampler_batch_size = 1
batch_max_len = cfg.micro_batch_size * cfg.sequence_len batch_max_len = cfg.micro_batch_size * cfg.sequence_len
else: else: