support for auto_find_batch_size when packing (#1885)
* support for auto_find_batch_size when packing * make sure to return data from validation * make sure to return data from validation * actually expose multipack_real_batches in the config * calculate gathered efficiency in sampler * tweak to fix auto find and use actual sampler len for multipack * uncomment * use args for bsz when not available from auto find
This commit is contained in:
@@ -506,9 +506,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_train_batch_size * self.args.max_seq_length
|
||||
train_batch_size = (
|
||||
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||
)
|
||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
@@ -1379,6 +1380,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs[
|
||||
"per_device_eval_batch_size"
|
||||
] = self.cfg.eval_batch_size
|
||||
if self.cfg.auto_find_batch_size is not None:
|
||||
training_arguments_kwargs[
|
||||
"auto_find_batch_size"
|
||||
] = self.cfg.auto_find_batch_size
|
||||
training_arguments_kwargs[
|
||||
"gradient_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
@@ -1461,9 +1466,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
|
||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||
training_arguments_kwargs[
|
||||
"multipack_real_batches"
|
||||
] = not self.cfg.flash_attention
|
||||
training_arguments_kwargs["multipack_real_batches"] = (
|
||||
not self.cfg.flash_attention or self.cfg.multipack_real_batches
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = bool(
|
||||
self.cfg.eval_sample_packing
|
||||
)
|
||||
|
||||
@@ -355,6 +355,8 @@ class HyperparametersConfig(BaseModel):
|
||||
},
|
||||
)
|
||||
|
||||
auto_find_batch_size: Optional[bool] = None
|
||||
|
||||
train_on_inputs: Optional[bool] = False
|
||||
group_by_length: Optional[bool] = None
|
||||
|
||||
@@ -592,6 +594,7 @@ class AxolotlInputConfig(
|
||||
eval_sample_packing: Optional[bool] = None
|
||||
pad_to_sequence_len: Optional[bool] = None
|
||||
curriculum_sampling: Optional[bool] = None
|
||||
multipack_real_batches: Optional[bool] = None
|
||||
|
||||
# for PoSE context length extension
|
||||
use_pose: Optional[bool] = None
|
||||
|
||||
@@ -11,6 +11,8 @@ import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler, Sampler
|
||||
|
||||
from axolotl.utils.distributed import reduce_and_broadcast
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
|
||||
|
||||
|
||||
@@ -174,16 +176,46 @@ class MultipackBatchSampler(BatchSampler):
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
|
||||
def gather_efficiency(self):
|
||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||
return math.floor(0.997 * max(estimates))
|
||||
|
||||
sample_packing_actual_eff_all = reduce_and_broadcast(
|
||||
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
|
||||
calc_sample_packing_eff_est,
|
||||
)
|
||||
sample_packing_eff_est = (
|
||||
math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
|
||||
)
|
||||
return sample_packing_eff_est
|
||||
|
||||
def gather_len_batches(self, num):
|
||||
def calc_min_len(estimates: list[(int, float)]):
|
||||
LOG.info(f"gather_len_batches: {repr(estimates)}")
|
||||
return math.floor(0.998 * min(estimates))
|
||||
|
||||
min_len_batches = reduce_and_broadcast(
|
||||
lambda: num,
|
||||
calc_min_len,
|
||||
)
|
||||
return min_len_batches
|
||||
|
||||
def __len__(self):
|
||||
self.num_batches()
|
||||
return self._len_est()
|
||||
len_batches = self.num_batches()
|
||||
return self.gather_len_batches(len_batches)
|
||||
|
||||
def _len_est(self):
|
||||
efficiency = (
|
||||
self.packing_efficiency_estimate
|
||||
if self.packing_efficiency_estimate
|
||||
else self.gather_efficiency()
|
||||
)
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // world_size
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"packing_efficiency_estimate: {efficiency} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
)
|
||||
|
||||
@@ -195,7 +227,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
* math.floor(
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
/ efficiency
|
||||
// (self.batch_max_len * self.batch_size)
|
||||
)
|
||||
- 1
|
||||
|
||||
@@ -357,7 +357,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
main_process_only=True,
|
||||
)
|
||||
else:
|
||||
if cfg.flash_attention:
|
||||
if cfg.flash_attention and not cfg.multipack_real_batches:
|
||||
sampler_batch_size = 1
|
||||
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user