Compare commits
1 Commits
v0.5.1
...
20240307-u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b432346e3 |
@@ -741,7 +741,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
warmup_steps = None
|
||||
if self.cfg.warmup_steps is not None:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio is not None:
|
||||
|
||||
@@ -11,7 +11,7 @@ import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from peft import PeftModel
|
||||
from peft import PeftModel, PeftModelForCausalLM
|
||||
from pkg_resources import get_distribution # type: ignore
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
@@ -207,6 +207,20 @@ def train(
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
if cfg.adapter and isinstance(model, (PeftModel, PeftModelForCausalLM)):
|
||||
model.to("cpu")
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
progressbar=True,
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
try:
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
|
||||
@@ -114,7 +114,9 @@ def prepare_dataset(cfg, tokenizer):
|
||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||
if total_eval_steps == 0:
|
||||
raise ValueError(
|
||||
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
|
||||
"eval dataset split is too small for sample_packing. "
|
||||
"You should set `eval_sample_packing: False` "
|
||||
"or decrease the value of `eval_batch_size`. "
|
||||
)
|
||||
|
||||
if cfg.max_steps:
|
||||
|
||||
@@ -5,7 +5,7 @@ Multipack Batch Sampler
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Iterable, List, Union
|
||||
from typing import Any, Iterable, List, Union, Optional
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
@@ -115,12 +115,14 @@ class MultipackBatchSampler(BatchSampler):
|
||||
batch_max_len: int,
|
||||
lengths: np.ndarray,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
consistent_length: Optional[bool] = False,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.batch_size = batch_size
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.consistent_length = consistent_length
|
||||
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
@@ -164,11 +166,18 @@ class MultipackBatchSampler(BatchSampler):
|
||||
|
||||
def __iter__(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
return iter(batches)
|
||||
if self.consistent_length:
|
||||
length = self._len_est()
|
||||
return iter(batches[:length])
|
||||
else:
|
||||
return iter(batches)
|
||||
|
||||
def num_batches(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
return len(batches)
|
||||
if self.consistent_length:
|
||||
return self._len_est()
|
||||
else:
|
||||
return len(batches)
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
|
||||
@@ -277,7 +277,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
calc_sample_packing_eff_est,
|
||||
)
|
||||
sample_packing_eff_est = (
|
||||
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
||||
math.ceil(sample_packing_actual_eff_all * 10000.0) / 10000.0
|
||||
)
|
||||
if update:
|
||||
cfg.sample_packing_eff_est = sample_packing_eff_est
|
||||
|
||||
Reference in New Issue
Block a user