Compare commits

...

1 commit

Author SHA1 Message Date
Wing Lian
3b432346e3 WIP 2024-03-07 08:30:13 -05:00
5 changed files with 31 additions and 7 deletions

View File

@@ -741,7 +741,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         return AxolotlTrainer

     def build(self, total_num_steps):
-        warmup_steps = None
         if self.cfg.warmup_steps is not None:
             warmup_steps = self.cfg.warmup_steps
         elif self.cfg.warmup_ratio is not None:

View File

@@ -11,7 +11,7 @@ import torch
 import transformers.modelcard
 from accelerate.logging import get_logger
 from datasets import Dataset
-from peft import PeftModel
+from peft import PeftModel, PeftModelForCausalLM
 from pkg_resources import get_distribution  # type: ignore
 from transformers import PreTrainedModel, PreTrainedTokenizer
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -207,6 +207,20 @@ def train(
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

+    if cfg.adapter and isinstance(model, (PeftModel, PeftModelForCausalLM)):
+        model.to("cpu")
+        model = model.merge_and_unload()
+        if cfg.local_rank == 0:
+            LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
+            model.save_pretrained(
+                str(Path(cfg.output_dir) / "merged"),
+                safe_serialization=safe_serialization,
+                progressbar=True,
+            )
+            tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+
     if not cfg.hub_model_id:
         try:
             trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))

View File

@@ -114,7 +114,9 @@ def prepare_dataset(cfg, tokenizer):
         total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
         if total_eval_steps == 0:
             raise ValueError(
-                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
+                "eval dataset split is too small for sample_packing. "
+                "You should set `eval_sample_packing: False` "
+                "or decrease the value of `eval_batch_size`. "
             )
     if cfg.max_steps:
if cfg.max_steps: if cfg.max_steps:

View File

@@ -5,7 +5,7 @@ Multipack Batch Sampler
 import logging
 import math
 import os
-from typing import Any, Iterable, List, Union
+from typing import Any, Iterable, List, Union, Optional

 import numba
 import numpy as np
@@ -115,12 +115,14 @@ class MultipackBatchSampler(BatchSampler):
         batch_max_len: int,
         lengths: np.ndarray,
         packing_efficiency_estimate: float = 1.0,
+        consistent_length: Optional[bool] = False,
     ):
         super().__init__(sampler, batch_size, drop_last)
         self.batch_size = batch_size
         self.batch_max_len = batch_max_len
         self.lengths: np.ndarray = lengths
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
+        self.consistent_length = consistent_length
         assert isinstance(self.lengths, np.ndarray)
@@ -164,11 +166,18 @@ class MultipackBatchSampler(BatchSampler):
     def __iter__(self):
         batches = self.generate_batches(set_stats=True)
-        return iter(batches)
+        if self.consistent_length:
+            length = self._len_est()
+            return iter(batches[:length])
+        else:
+            return iter(batches)

     def num_batches(self):
         batches = self.generate_batches(set_stats=True)
-        return len(batches)
+        if self.consistent_length:
+            return self._len_est()
+        else:
+            return len(batches)

     def efficiency(self):
         return self.eff_total_used / self.eff_total_slots

View File

@@ -277,7 +277,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             calc_sample_packing_eff_est,
         )
         sample_packing_eff_est = (
-            math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
+            math.ceil(sample_packing_actual_eff_all * 10000.0) / 10000.0
         )
         if update:
             cfg.sample_packing_eff_est = sample_packing_eff_est