support for auto_find_batch_size when packing (#1885)

* support for auto_find_batch_size when packing

* make sure to return data from validation

* actually expose multipack_real_batches in the config

* calculate gathered efficiency in sampler

* tweak to fix auto find and use actual sampler len for multipack

* uncomment

* use args for bsz when not available from auto find
Wing Lian authored 2024-09-03 20:02:44 -04:00, committed by GitHub
parent 0aeb277456 · commit 4e5400c732
4 changed files with 50 additions and 10 deletions

File 1 of 4

@@ -506,9 +506,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             batch_max_len = self.args.max_seq_length
         else:
             batch_size = 1
-            batch_max_len = (
-                self.args.per_device_train_batch_size * self.args.max_seq_length
+            train_batch_size = (
+                self.state.train_batch_size or self.args.per_device_train_batch_size
             )
+            batch_max_len = train_batch_size * self.args.max_seq_length
         return MultipackBatchSampler(
             RandomSampler(self.train_dataset),
             lengths=get_dataset_lengths(self.train_dataset),
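
This hunk is the core fix: when auto_find_batch_size halves the batch size after an OOM, HF's Trainer records the new value on TrainerState.train_batch_size, and the multipack sampler's token budget must follow it rather than the stale configured value. An illustrative sketch, not part of the diff (the helper name is mine):

from typing import Optional

def resolve_batch_max_len(
    state_train_batch_size: Optional[int],  # set by the auto-find retry loop
    per_device_train_batch_size: int,       # static value from TrainingArguments
    max_seq_length: int,
) -> int:
    # state.train_batch_size is None until training starts, so fall back
    # to the configured per-device batch size, exactly as the diff does.
    train_batch_size = state_train_batch_size or per_device_train_batch_size
    return train_batch_size * max_seq_length

assert resolve_batch_max_len(None, 4, 2048) == 8192  # before any OOM retry
assert resolve_batch_max_len(2, 4, 2048) == 4096     # after auto-find halved bsz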
@@ -1379,6 +1380,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "per_device_eval_batch_size"
             ] = self.cfg.eval_batch_size
+        if self.cfg.auto_find_batch_size is not None:
+            training_arguments_kwargs[
+                "auto_find_batch_size"
+            ] = self.cfg.auto_find_batch_size
         training_arguments_kwargs[
             "gradient_accumulation_steps"
         ] = self.cfg.gradient_accumulation_steps
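
For context, auto_find_batch_size in transformers wraps the inner training loop with accelerate's find_executable_batch_size, which catches CUDA OOM and retries with half the batch size. A simplified model of that loop (stand-in exception, not the real accelerate code):

def find_executable_batch_size_sketch(train_fn, starting_batch_size=16):
    # Halve on OOM until something fits; accelerate's real helper also
    # clears CUDA caches between attempts.
    batch_size = starting_batch_size
    while batch_size >= 1:
        try:
            return train_fn(batch_size)
        except MemoryError:  # stand-in for torch.cuda.OutOfMemoryError
            batch_size //= 2
    raise RuntimeError("no executable batch size found")

def fake_step(batch_size):
    if batch_size > 4:  # pretend anything larger than 4 OOMs
        raise MemoryError
    return batch_size

print(find_executable_batch_size_sketch(fake_step))  # 16 -> 8 -> 4, returns 4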
@@ -1461,9 +1466,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )
         training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
-        training_arguments_kwargs[
-            "multipack_real_batches"
-        ] = not self.cfg.flash_attention
+        training_arguments_kwargs["multipack_real_batches"] = (
+            not self.cfg.flash_attention or self.cfg.multipack_real_batches
+        )
         training_arguments_kwargs["eval_sample_packing"] = bool(
             self.cfg.eval_sample_packing
         )
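
Taken together, the two builder changes forward auto_find_batch_size only when the user actually set it, and let multipack_real_batches be forced on explicitly even when flash attention is enabled. A minimal sketch with a hypothetical cfg stand-in:

from types import SimpleNamespace

def packing_kwargs(cfg):
    kwargs = {}
    if cfg.auto_find_batch_size is not None:  # unset keeps HF's default
        kwargs["auto_find_batch_size"] = cfg.auto_find_batch_size
    # real batches when flash attention is off, or when explicitly requested
    kwargs["multipack_real_batches"] = (
        not cfg.flash_attention or cfg.multipack_real_batches
    )
    return kwargs

cfg = SimpleNamespace(
    auto_find_batch_size=True, flash_attention=True, multipack_real_batches=True
)
print(packing_kwargs(cfg))
# {'auto_find_batch_size': True, 'multipack_real_batches': True}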

File 2 of 4

@@ -355,6 +355,8 @@ class HyperparametersConfig(BaseModel):
         },
     )
+    auto_find_batch_size: Optional[bool] = None
+
     train_on_inputs: Optional[bool] = False
     group_by_length: Optional[bool] = None
@@ -592,6 +594,7 @@ class AxolotlInputConfig(
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
     curriculum_sampling: Optional[bool] = None
+    multipack_real_batches: Optional[bool] = None
 
     # for PoSE context length extension
     use_pose: Optional[bool] = None
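
The Optional[bool] = None defaults matter: they let the builder distinguish "not configured" from an explicit False. A stripped-down stand-in, not the full axolotl config model:

from typing import Optional
from pydantic import BaseModel

class PackingOpts(BaseModel):
    # hypothetical subset mirroring the two new fields above
    auto_find_batch_size: Optional[bool] = None
    multipack_real_batches: Optional[bool] = None

print(PackingOpts().auto_find_batch_size)       # None -> kwarg not forwarded
print(PackingOpts(auto_find_batch_size=False))  # explicit False is preserved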

File 3 of 4

@@ -11,6 +11,8 @@ import numba
 import numpy as np
 from torch.utils.data import BatchSampler, Sampler
 
+from axolotl.utils.distributed import reduce_and_broadcast
+
 LOG = logging.getLogger("axolotl.utils.samplers.multipack")
@@ -174,16 +176,46 @@ class MultipackBatchSampler(BatchSampler):
     def efficiency(self):
         return self.eff_total_used / self.eff_total_slots
 
+    def gather_efficiency(self):
+        def calc_sample_packing_eff_est(estimates: List[float]):
+            LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
+            return math.floor(0.997 * max(estimates))
+
+        sample_packing_actual_eff_all = reduce_and_broadcast(
+            lambda: self.efficiency(),  # pylint: disable=unnecessary-lambda
+            calc_sample_packing_eff_est,
+        )
+        sample_packing_eff_est = (
+            math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
+        )
+        return sample_packing_eff_est
+
+    def gather_len_batches(self, num):
+        def calc_min_len(estimates: list[(int, float)]):
+            LOG.info(f"gather_len_batches: {repr(estimates)}")
+            return math.floor(0.998 * min(estimates))
+
+        min_len_batches = reduce_and_broadcast(
+            lambda: num,
+            calc_min_len,
+        )
+        return min_len_batches
+
     def __len__(self):
-        self.num_batches()
-        return self._len_est()
+        len_batches = self.num_batches()
+        return self.gather_len_batches(len_batches)
 
     def _len_est(self):
+        efficiency = (
+            self.packing_efficiency_estimate
+            if self.packing_efficiency_estimate
+            else self.gather_efficiency()
+        )
         world_size = int(os.getenv("WORLD_SIZE", "1"))
         lengths_sum = np.sum(self.lengths)
         lengths_sum_per_device = lengths_sum // world_size
         LOG.info(
-            f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
+            f"packing_efficiency_estimate: {efficiency} "
             f"total_num_tokens per device: {lengths_sum_per_device}"
         )
@@ -195,7 +227,7 @@ class MultipackBatchSampler(BatchSampler):
             * math.floor(
                 0.99
                 * lengths_sum_per_device
-                / self.packing_efficiency_estimate
+                / efficiency
                 // (self.batch_max_len * self.batch_size)
             )
             - 1
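
The new gather methods exist because every rank must report the same __len__: if one rank's sampler packs slightly fewer batches and finishes early, the others block in collectives. reduce_and_broadcast (from axolotl.utils.distributed) gathers a per-rank value, applies a reduction, and broadcasts the result; taking a floor of the minimum batch count keeps all ranks in lockstep, and gather_efficiency similarly shares a packing-efficiency estimate rounded up to the nearest 0.005 via ceil(x * 200) / 200. A local sketch of that contract (single-process stand-in for the collective):

import math

def reduce_and_broadcast_sketch(gathered_values, reduce_fn):
    # Stand-in: pretend gathered_values arrived from all ranks and the
    # reduced result is broadcast back to each of them.
    return reduce_fn(gathered_values)

per_rank_num_batches = [103, 101, 104, 102]  # ranks disagree slightly
shared_len = reduce_and_broadcast_sketch(
    per_rank_num_batches,
    lambda counts: math.floor(0.998 * min(counts)),  # same rule as the diff
)
print(shared_len)  # 100 on every rank, so no rank iterates past the others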

File 4 of 4

@@ -357,7 +357,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 main_process_only=True,
             )
         else:
-            if cfg.flash_attention:
+            if cfg.flash_attention and not cfg.multipack_real_batches:
                 sampler_batch_size = 1
                 batch_max_len = cfg.micro_batch_size * cfg.sequence_len
             else:
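
This mirrors the sampler construction in the trainer: with flash attention and "virtual" batches, multipack flattens micro_batch_size sequences into one long sample of micro_batch_size * sequence_len tokens; with multipack_real_batches set, it emits ordinary batches instead. A sketch of the two shapes (helper name is mine, assuming the else branch mirrors the trainer's real-batch path):

def sampler_shape(flash_attention, multipack_real_batches,
                  micro_batch_size, sequence_len):
    # Returns (sampler_batch_size, batch_max_len) as the branch above does.
    if flash_attention and not multipack_real_batches:
        return 1, micro_batch_size * sequence_len  # packed "virtual" batch
    return micro_batch_size, sequence_len          # real batches

print(sampler_shape(True, False, 4, 2048))  # (1, 8192)
print(sampler_shape(True, True, 4, 2048))   # (4, 2048)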