more fixes for 4k and optimizations
This commit is contained in:
@@ -174,10 +174,18 @@ def drop_long_seq(sample, sequence_len=2048):
|
|||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||||
train_dataset = train_dataset.filter(drop_long).map(add_position_ids)
|
train_dataset = train_dataset.filter(drop_long).map(
|
||||||
eval_dataset = eval_dataset.filter(drop_long).map(add_position_ids)
|
add_position_ids, num_proc=os.cpu_count()
|
||||||
|
)
|
||||||
|
eval_dataset = eval_dataset.filter(drop_long).map(
|
||||||
|
add_position_ids, num_proc=os.cpu_count()
|
||||||
|
)
|
||||||
if cfg.sample_packing_eff_est:
|
if cfg.sample_packing_eff_est:
|
||||||
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
|
total_num_tokens = (
|
||||||
|
cfg.total_num_tokens
|
||||||
|
if cfg.total_num_tokens
|
||||||
|
else sum(len(s["input_ids"]) for s in train_dataset)
|
||||||
|
)
|
||||||
total_num_steps = (
|
total_num_steps = (
|
||||||
math.ceil(
|
math.ceil(
|
||||||
total_num_tokens
|
total_num_tokens
|
||||||
@@ -187,6 +195,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
)
|
)
|
||||||
|
LOG.info(
|
||||||
|
f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
sampler = RandomSampler(train_dataset)
|
sampler = RandomSampler(train_dataset)
|
||||||
data_loader = MultipackDistributedDataloader(
|
data_loader = MultipackDistributedDataloader(
|
||||||
@@ -295,6 +306,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
||||||
|
max_seq_length=cfg.sequence_len,
|
||||||
per_device_train_batch_size=cfg.micro_batch_size,
|
per_device_train_batch_size=cfg.micro_batch_size,
|
||||||
per_device_eval_batch_size=cfg.eval_batch_size
|
per_device_eval_batch_size=cfg.eval_batch_size
|
||||||
if cfg.eval_batch_size is not None
|
if cfg.eval_batch_size is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user