support for auto_find_batch_size when packing (#1885)
* support for auto_find_batch_size when packing
* make sure to return data from validation
* make sure to return data from validation
* actually expose multipack_real_batches in the config
* calculate gathered efficiency in sampler
* tweak to fix auto find and use actual sampler len for multipack
* uncomment
* use args for bsz when not available from auto find
This commit is contained in:
@@ -506,9 +506,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
batch_max_len = self.args.max_seq_length
|
batch_max_len = self.args.max_seq_length
|
||||||
else:
|
else:
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
batch_max_len = (
|
train_batch_size = (
|
||||||
self.args.per_device_train_batch_size * self.args.max_seq_length
|
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||||
)
|
)
|
||||||
|
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
RandomSampler(self.train_dataset),
|
RandomSampler(self.train_dataset),
|
||||||
lengths=get_dataset_lengths(self.train_dataset),
|
lengths=get_dataset_lengths(self.train_dataset),
|
||||||
@@ -1379,6 +1380,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"per_device_eval_batch_size"
|
"per_device_eval_batch_size"
|
||||||
] = self.cfg.eval_batch_size
|
] = self.cfg.eval_batch_size
|
||||||
|
if self.cfg.auto_find_batch_size is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"auto_find_batch_size"
|
||||||
|
] = self.cfg.auto_find_batch_size
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"gradient_accumulation_steps"
|
"gradient_accumulation_steps"
|
||||||
] = self.cfg.gradient_accumulation_steps
|
] = self.cfg.gradient_accumulation_steps
|
||||||
@@ -1461,9 +1466,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["multipack_real_batches"] = (
|
||||||
"multipack_real_batches"
|
not self.cfg.flash_attention or self.cfg.multipack_real_batches
|
||||||
] = not self.cfg.flash_attention
|
)
|
||||||
training_arguments_kwargs["eval_sample_packing"] = bool(
|
training_arguments_kwargs["eval_sample_packing"] = bool(
|
||||||
self.cfg.eval_sample_packing
|
self.cfg.eval_sample_packing
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -355,6 +355,8 @@ class HyperparametersConfig(BaseModel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
auto_find_batch_size: Optional[bool] = None
|
||||||
|
|
||||||
train_on_inputs: Optional[bool] = False
|
train_on_inputs: Optional[bool] = False
|
||||||
group_by_length: Optional[bool] = None
|
group_by_length: Optional[bool] = None
|
||||||
|
|
||||||
@@ -592,6 +594,7 @@ class AxolotlInputConfig(
|
|||||||
eval_sample_packing: Optional[bool] = None
|
eval_sample_packing: Optional[bool] = None
|
||||||
pad_to_sequence_len: Optional[bool] = None
|
pad_to_sequence_len: Optional[bool] = None
|
||||||
curriculum_sampling: Optional[bool] = None
|
curriculum_sampling: Optional[bool] = None
|
||||||
|
multipack_real_batches: Optional[bool] = None
|
||||||
|
|
||||||
# for PoSE context length extension
|
# for PoSE context length extension
|
||||||
use_pose: Optional[bool] = None
|
use_pose: Optional[bool] = None
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import numba
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import BatchSampler, Sampler
|
from torch.utils.data import BatchSampler, Sampler
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
|
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
|
||||||
|
|
||||||
|
|
||||||
@@ -174,16 +176,46 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
def efficiency(self):
|
def efficiency(self):
|
||||||
return self.eff_total_used / self.eff_total_slots
|
return self.eff_total_used / self.eff_total_slots
|
||||||
|
|
||||||
|
def gather_efficiency(self):
|
||||||
|
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||||
|
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||||
|
return math.floor(0.997 * max(estimates))
|
||||||
|
|
||||||
|
sample_packing_actual_eff_all = reduce_and_broadcast(
|
||||||
|
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
|
||||||
|
calc_sample_packing_eff_est,
|
||||||
|
)
|
||||||
|
sample_packing_eff_est = (
|
||||||
|
math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
|
||||||
|
)
|
||||||
|
return sample_packing_eff_est
|
||||||
|
|
||||||
|
def gather_len_batches(self, num):
|
||||||
|
def calc_min_len(estimates: list[(int, float)]):
|
||||||
|
LOG.info(f"gather_len_batches: {repr(estimates)}")
|
||||||
|
return math.floor(0.998 * min(estimates))
|
||||||
|
|
||||||
|
min_len_batches = reduce_and_broadcast(
|
||||||
|
lambda: num,
|
||||||
|
calc_min_len,
|
||||||
|
)
|
||||||
|
return min_len_batches
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
self.num_batches()
|
len_batches = self.num_batches()
|
||||||
return self._len_est()
|
return self.gather_len_batches(len_batches)
|
||||||
|
|
||||||
def _len_est(self):
|
def _len_est(self):
|
||||||
|
efficiency = (
|
||||||
|
self.packing_efficiency_estimate
|
||||||
|
if self.packing_efficiency_estimate
|
||||||
|
else self.gather_efficiency()
|
||||||
|
)
|
||||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||||
lengths_sum = np.sum(self.lengths)
|
lengths_sum = np.sum(self.lengths)
|
||||||
lengths_sum_per_device = lengths_sum // world_size
|
lengths_sum_per_device = lengths_sum // world_size
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
f"packing_efficiency_estimate: {efficiency} "
|
||||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -195,7 +227,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
* math.floor(
|
* math.floor(
|
||||||
0.99
|
0.99
|
||||||
* lengths_sum_per_device
|
* lengths_sum_per_device
|
||||||
/ self.packing_efficiency_estimate
|
/ efficiency
|
||||||
// (self.batch_max_len * self.batch_size)
|
// (self.batch_max_len * self.batch_size)
|
||||||
)
|
)
|
||||||
- 1
|
- 1
|
||||||
|
|||||||
@@ -357,7 +357,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
main_process_only=True,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if cfg.flash_attention:
|
if cfg.flash_attention and not cfg.multipack_real_batches:
|
||||||
sampler_batch_size = 1
|
sampler_batch_size = 1
|
||||||
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
|
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user