more fixes for 4k and optimizations
@@ -174,10 +174,18 @@ def drop_long_seq(sample, sequence_len=2048):
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.sample_packing:
         drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
-        train_dataset = train_dataset.filter(drop_long).map(add_position_ids)
-        eval_dataset = eval_dataset.filter(drop_long).map(add_position_ids)
+        train_dataset = train_dataset.filter(drop_long).map(
+            add_position_ids, num_proc=os.cpu_count()
+        )
+        eval_dataset = eval_dataset.filter(drop_long).map(
+            add_position_ids, num_proc=os.cpu_count()
+        )
         if cfg.sample_packing_eff_est:
-            total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
+            total_num_tokens = (
+                cfg.total_num_tokens
+                if cfg.total_num_tokens
+                else sum(len(s["input_ids"]) for s in train_dataset)
+            )
             total_num_steps = (
                 math.ceil(
                     total_num_tokens
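
A minimal standalone sketch of what the parallel map above does, assuming a Hugging Face datasets.Dataset with an input_ids column; the bodies of drop_long_seq and add_position_ids are reconstructed from their names and are not copied from the axolotl source:

import os
from functools import partial

from datasets import Dataset


def drop_long_seq(sample, sequence_len=2048):
    # Keep only samples that fit inside the packing window (assumed body).
    return len(sample["input_ids"]) <= sequence_len


def add_position_ids(sample):
    # Attach explicit per-sample position ids for packed sequences (assumed body).
    sample["position_ids"] = list(range(len(sample["input_ids"])))
    return sample


if __name__ == "__main__":
    ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5], list(range(5000))]})
    drop_long = partial(drop_long_seq, sequence_len=4096)
    # num_proc fans the map out across worker processes, which is the
    # optimization this commit introduces; capped here for the tiny toy dataset.
    ds = ds.filter(drop_long).map(
        add_position_ids, num_proc=min(os.cpu_count() or 1, len(ds))
    )
    print(ds)
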
@@ -187,6 +195,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
                 )
                 * cfg.num_epochs
             )
+            LOG.info(
+                f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}"
+            )
         else:
             sampler = RandomSampler(train_dataset)
             data_loader = MultipackDistributedDataloader(
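
The token-count override in the first hunk avoids summing len(input_ids) over the whole dataset when cfg.total_num_tokens is already known, and the LOG.info above surfaces the resulting estimate. A rough sketch of that estimate with made-up numbers; the middle of the math.ceil expression is cut off between the two hunks, so the divisors below are assumptions, not the actual formula:

import math
from types import SimpleNamespace

cfg = SimpleNamespace(
    total_num_tokens=None,        # set to an int to skip the expensive sum
    sample_packing_eff_est=0.95,  # estimated packing efficiency
    sequence_len=4096,
    batch_size=8,
    num_epochs=3,
)
train_dataset = [{"input_ids": list(range(n))} for n in (3000, 4096, 1500)]

total_num_tokens = (
    cfg.total_num_tokens
    if cfg.total_num_tokens
    else sum(len(s["input_ids"]) for s in train_dataset)
)
# Assumed divisors: tokens adjusted by packing efficiency, then split into
# sequences and batches, repeated per epoch.
total_num_steps = (
    math.ceil(
        total_num_tokens / cfg.sample_packing_eff_est / cfg.sequence_len / cfg.batch_size
    )
    * cfg.num_epochs
)
print(f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}")
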
@@ -295,6 +306,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         # max_steps=total_num_steps,  # this is helpful in case we don't actually know total # of steps
+        max_seq_length=cfg.sequence_len,
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
         if cfg.eval_batch_size is not None
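
The new max_seq_length keyword is accepted because AxolotlTrainingArguments extends the stock transformers TrainingArguments with extra fields, which is also why the call carries the pylint unexpected-keyword-arg disable. A guess at the general shape, not the actual axolotl definition:

from dataclasses import dataclass, field

from transformers import TrainingArguments


@dataclass
class AxolotlTrainingArguments(TrainingArguments):
    # Hypothetical extra field; the real class may define more and differ in detail.
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "maximum sequence length, forwarded from cfg.sequence_len"},
    )


# Usage mirroring the diff; the eval-batch-size fallback branch is cut off in
# the hunk above, so it is left out here.
args = AxolotlTrainingArguments(
    output_dir="./out",
    max_seq_length=4096,
    per_device_train_batch_size=2,
)
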